# tensorflow2.0-手写数字识别案例 (TensorFlow 2.0: handwritten digit recognition example)

import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# Load MNIST: each split is a pair (images, integer labels).
(Xtrain, Ytrain), (Xtest, Ytest) = mnist.load_data()

# Standardize every pixel position to zero mean / unit variance.
# StandardScaler expects 2-D input, so each image is flattened to one row,
# the statistics are fitted on the training split only (to avoid leaking
# test-set information), and the result is reshaped back to image form.
scaler = StandardScaler()
scaler.fit(Xtrain.reshape(len(Xtrain), -1))
Xtrain_scaled = scaler.transform(Xtrain.reshape(len(Xtrain), -1)).reshape(Xtrain.shape)
Xtest_scaled = scaler.transform(Xtest.reshape(len(Xtest), -1)).reshape(Xtest.shape)

# Resulting shapes: Xtrain_scaled (60000, 28, 28), Xtest_scaled (10000, 28, 28)

# Build a fully connected classifier:
# 28x28 image -> flattened vector -> two ReLU hidden layers of 128 units
# -> 10-unit softmax output (one probability per digit class).
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.summary()

# Compile the model. The output layer already applies softmax, so the loss
# receives probabilities (from_logits=False). The sparse cross-entropy
# variant takes integer class labels directly, with no one-hot encoding.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['acc'],
)

# Train the model on the *standardized* images produced by the scaler above.
# Passing the raw 0-255 pixels (Xtrain) would silently bypass that
# preprocessing step.
# The standardized test split serves as validation data, so the model's
# generalization can be tracked after every epoch.
history = model.fit(
    Xtrain_scaled,
    Ytrain,
    epochs=15,
    validation_data=(Xtest_scaled, Ytest),
)
# history.history keys: 'loss', 'acc', 'val_loss', 'val_acc'

# Plot per-epoch loss and accuracy: solid lines are training, dashed are validation.
plt.plot(history.epoch, history.history['loss'], color='red', label='loss')
plt.plot(history.epoch, history.history['val_loss'], color='red', linestyle='--', label='val_loss')
plt.plot(history.epoch, history.history['acc'], color='blue', label='acc')
plt.plot(history.epoch, history.history['val_acc'], color='blue', linestyle='--', label='val_acc')
plt.xlabel('epoch')
plt.legend()
plt.show()

# [Figure 1 of the original article: tensorflow2.0-手写数字识别案例_第1张图片]
# [Figure 2 of the original article: tensorflow2.0-手写数字识别案例_第2张图片]

# 你可能感兴趣的 (Related tags): tensorflow, 神经网络 (neural networks), 深度学习 (deep learning), 机器学习 (machine learning), 图像识别 (image recognition)