大家好,欢迎来到IT知识分享网。
先上代码~
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from sklearn.mixture import GaussianMixture
#产生实验数据
from sklearn.datasets.samples_generator import make_blobs # make_blobs是生成聚类使用的数据集
X, labels_true = make_blobs(n_samples = 400, n_features = 2, centers = 4, cluster_std = 0.60, random_state = 0)
# 模型训练
gmm = GaussianMixture(n_components=4).fit(X)
labels = gmm.predict(X)
# 作图
plt.figure(1)
plt.scatter(X[:, 0], X[:, 1], c=labels_true, s=40, cmap='viridis')
plt.figure(2)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis')
plt.show()
一些解释:
1. 去掉 PyCharm 的警告提示:
import warnings
warnings.filterwarings('ignore')
2. 生成聚类数据集:
from sklearn.datasets.samples_generator import make.blobs
samples, labels = make.blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0, center_box=(-10.0, 10.0), shuffle=True, random_state=0)
# 输入参数:
# n_samples:生成数据集的点数
# n_features:数据集中数据的维度
# centers:数据聚类的标签个数(类别数)
# cluster_std:数据的标准差
# center_box:中心确定之后的数据边界
# shuffle:洗乱
# random_state:随机生成器的种子
# 输出参数:
# samples:产生的数据集
# labels:数据集对应的标签
3. 训练GMM模型:
gmm = GaussianMixture(n_components=1).fit(samples)
lables_gmm = gmm.predict(samples)
# 一些解释:
# n_components:混合高斯模型个数
# .fit(samples):基于 samples 样本使用 EM 算法训练 GMM 模型
# .predict(samples):使用训练得到的 GMM 模型预测样本数据标签
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/25017.html