bp神经网络利用sklearn自带mnist手写数据集实现对数字的识别
首先导入数据
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
digits = load_digits()
x_data = digits.data
y_data = digits.target
##print(x_data.shape)
##print(y_data.shape)
y=digits.data[1].reshape([8, 8])
拆分训练集和测试集
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data)
构建模型
#构建模型,2个隐藏层,第一个隐藏层有100个神经元,第2隐藏层50个神经元,训练500周期
mlp = MLPClassifier(hidden_layer_sizes=(100,50), max_iter=500)
训练fit
mlp.fit(x_train,y_train)
测试集展示
#测试集准确率的评估
predictions = mlp.predict(x_test)
for i in range(1, 21):
plt.subplot(4,5, i) #划分成2行5列
plt.imshow(x_test[i - 1].reshape([8, 8]), cmap=plt.cm.gray_r)
plt.text(60, 3, str(predictions[i-1])) #在图片的任意位置添加文本
plt.xticks([]) #认为设置坐标轴显示的刻度值
plt.yticks([])
plt.rcParams['savefig.dpi'] = 128 #图片像素
plt.rcParams['figure.dpi'] = 128 #分辨率
plt.subplots_adjust(bottom=0.10,top=1.5)
plt.show()
效果如下:
完整代码:
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
digits = load_digits()
x_data = digits.data
y_data = digits.target
##print(x_data.shape)
##print(y_data.shape)
y=digits.data[1].reshape([8, 8])
#数据拆分
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data)
#构建模型,2个隐藏层,第一个隐藏层有100个神经元,第2隐藏层50个神经元,训练500周期
mlp = MLPClassifier(hidden_layer_sizes=(100,50), max_iter=500)
mlp.fit(x_train,y_train)
#测试集准确率的评估
predictions = mlp.predict(x_test)
for i in range(1, 21):
plt.subplot(4,5, i) #划分成2行5列
plt.imshow(x_test[i - 1].reshape([8, 8]), cmap=plt.cm.gray_r)
plt.text(60, 3, str(predictions[i-1])) #在图片的任意位置添加文本
plt.xticks([]) #认为设置坐标轴显示的刻度值
plt.yticks([])
plt.rcParams['savefig.dpi'] = 128 #图片像素
plt.rcParams['figure.dpi'] = 128 #分辨率
plt.subplots_adjust(bottom=0.10,top=1.5)
plt.show()
##print(classification_report(y_test, predictions))