import numpy as np
import matplotlib.pyplot as plt
# 1.散点输入
class1_points = np.array([[1.9, 1.2],
[1.5, 2.1],
[1.9, 0.5],
[1.5, 0.9],
[0.9, 1.2],
[1.1, 1.7],
[1.4, 1.1]])
class2_points = np.array([[3.2, 3.2],
[3.7, 2.9],
[3.2, 2.6],
[1.7, 3.3],
[3.4, 2.6],
[4.1, 2.3],
[3.0, 2.9]])
class3_points = np.array([[3.3, 1.2],
[3.8, 0.9],
[3.3, 0.6],
[2.8, 1.3],
[3.5, 0.6],
[4.2, 0.3],
[3.1, 0.9]])
#合并数据集 创造标签
X=np.concatenate((class1_points,class2_points,class3_points),axis=0)
Y=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class1_points)),np.ones(len(class1_points))+1),axis=0)
print(Y)
# 2.计算先验
# 2.1 每一个类别的数据在数据集中的比例
prior_prob=[np.sum(Y==0)/len(Y),np.sum(Y==1)/len(Y),np.sum(Y==2)/len(Y)]
#3.计算高斯分布的概率密度函数
# 求解包括蓝色点数据的均值 和红色点数据的均值
class_μ=[np.mean(X[Y==0],axis=0),np.mean(X[Y==1],axis=0),np.mean(X[Y==2],axis=0)]
#求协方差矩阵
class_cov=[np.cov(X[Y==0],rowvar=False),np.cov(X[Y==1],rowvar=False),np.cov(X[Y==2],rowvar=False)]
#使用for循环进行求解高斯概率密度函数
#获取新的坐标点(x1,x2)
def pdf(x,mean,cov):
#1.获取均值向量的长度,即特征的数量
n=len(mean)
#2计算系数
# numpy.linalg.det()函数计算输入矩阵的行列式
coff=1/(2*np.pi)**(n/2)*np.sqrt(np.linalg.det(cov))
#3计算指数部分
# np.dot()计算两个一维数组的内积
exponent=np.exp(-(1/2)*np.dot(np.dot((x-mean).T,np.linalg.inv(cov)),(x-mean)))
return coff*exponent
#获得xy轴上的足够多的坐标点
xx,yy=np.meshgrid(np.arange(0,5,0.05),np.arange(0,4,0.05))
#拿到预测点
# np.c_沿着矩阵的第二个轴拼接
grid_points=np.c_[xx.ravel(),yy.ravel()]
# point=np.array([3,1.5])
#存储后验结果
grid_label=[]
#预测网格点
for point in grid_points:
poster_prob = []
for i in range(3):
#使用概率密度函数求条件概率
likelihood=pdf(point,class_μ[i],class_cov[i])
#计算后验概率=先验*条件概率
poster_prob.append(prior_prob[i]*likelihood)
pre_class=np.argmax(poster_prob)
grid_label.append(pre_class)
#绘制散点图
plt.scatter(class1_points[:,0],class1_points[:,1],c="blue",label="class 1")
plt.scatter(class2_points[:,0],class2_points[:,1],c="red",label="class 2")
plt.scatter(class3_points[:,0],class3_points[:,1],c="yellow",label="class 3")
# plt.scatter(point[0],point[1],c="green",label='point')
#添加图例
plt.legend()
#显示决策边界
#预测标签和xx形状一致
#列表先转成数组
grid_label=np.array(grid_label)
pre_grid_label=grid_label.reshape(xx.shape)
#等高线绘制
contour=plt.contour(xx,yy,pre_grid_label,colors='black')
plt.show()