本篇基于Tensorflow 1.0进行神经网络创建和MNIST手写体识别训练
Tensorflow 1.0 的安装和通过机器学习实现MNIST手写体识别
Tensorflow实现MNIST手写体识别
前置条件
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from time import time
全连接层函数定义
# 定义全连接层函数,fcn(全连接网络)
def fcn_layer(inputs, # 输入数据
input_dim, # 输入神经元数量
output_dim, # 输出神经元数量
activation=None): # 激活函数
# 以截断正太分布的随机数初始化
W = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1))
# 以0初始化b
b = tf.Variable(tf.zeros([output_dim]))
# 建立表达式:input*W+b
XWB = tf.matmul(inputs, W) + b
# 默认不使用激活函数
if activation is None:
outputs = XWB
# 若传入激活函数,则用其对输出结果进行变换
else:
outputs = activation(XWB)
return outputs
MNIST训练集读取
mnist = input_data.read_data_sets('data/MNIST/', one_hot=True)
变量定义
# 定义标签数据占位符(输入层)
x = tf.placeholder(tf.float32, [None, 784], name='x')
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
tf.summary.image('input', image_shaped_input, 10)
y = tf.placeholder(tf.float32, [None, 10], name='y')
本次主要构建两层隐藏层神经网络
# 输出神经元个数
H1_NN = 256
H2_NN = 64
H3_NN = 32
# 构建输入层
h1 = fcn_layer(inputs=x,
input_dim=784,
output_dim=H1_NN,
activation=tf.nn.relu)
# 构建隐藏层1
h2 = fcn_layer(inputs=h1,
input_dim=H1_NN,
output_dim=H2_NN,
activation=tf.nn.relu)
# 构建隐藏层2
h3 = fcn_layer(inputs=h2,
input_dim=H2_NN,
output_dim=H3_NN,
activation=tf.nn.relu)
# 构建输出层
forward = fcn_layer(inputs=h3,
input_dim=H3_NN,
output_dim=10,
activation=None)
tf.summary.histogram('forward', forward)
pred = tf.nn.softmax(forward)
定义损失函数
# 定义损失函数,交叉熵
loss_function = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=forward, labels=y)
)
tf.summary.scalar('loss', loss_function)
定义优化器
# 选择优化器
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
定义准确率
# 定义准确率
correst_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correst_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
设置训练参数
# 设置训练参数
train_epochs = 40
batch_size = 50
total_batch = int(mnist.train.num_examples / batch_size)
display_step = 1
learning_rate = 0.01
训练模型存储
# 存储模型的粒度
save_step = 5
# 存储路径
ckep_dir = './ckpt_dir/'
if not os.path.exists(ckep_dir):
os.makedirs(ckep_dir)
# 模型保存
saver = tf.train.Saver()
模型恢复
# 模型恢复
ckpt = tf.train.get_checkpoint_state(ckep_dir)
if ckpt and ckpt.model_checkpoint_path:
# 从已保存的模型中读取参数
saver.restore(sess, ckpt.model_checkpoint_path)
print("Restore model from " + ckpt.model_checkpoint_path)
定义会话
sess = tf.Session()
sess.run(tf.global_variables_initializer())
写入模型数据流图
merged_summary_op = tf.summary.merge_all()
# 创建写入符,合并 所有的summary
writer = tf.summary.FileWriter('log/', sess.graph)
模型训练及模型保存
for epoch in range(train_epochs):
for batch in range(total_batch):
# 读取批次数据
xs, ys = mnist.train.next_batch(batch_size)
# 执行批次数据
sess.run(optimizer, feed_dict={x: xs, y: ys})
summary_str = sess.run(merged_summary_op, feed_dict={x: xs, y: ys})
writer.add_summary(summary_str, epoch)
# total_batch个批次训练完成后,使用验证数据计算误差和准确率
loss, acc = sess.run([loss_function, accuracy],
feed_dict={x: mnist.validation.images,
y: mnist.validation.labels})
if (epoch + 1) % display_step == 0:
print("Train Epoch:", '%02d' % (epoch + 1), 'Loss:{:.9f}'.format(loss), 'Accuracy:{:.4f}'.format(acc))
# 模型保存
if epoch % save_step == 0:
saver.save(sess, os.path.join(ckep_dir,
'mnist_h256_model_{:06d}.ckpt'.format(epoch + 1))) # 模型存储
print('mnist_h256_model_{:06d}.ckpt saved'.format(epoch + 1))
saver.save(sess, os.path.join(ckep_dir, 'mnist_h256_model.ckpt'))
print('Model saved!')
模型评估
# 评估模型
acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Test Accuracy:", acc_test)
该过程中创建了模型日志,将数据流图写入日志,需要通过Tensorboard进行打开查看
# ~为绝对路径,通过tensorboard打开存放日志路径中的日志 # 在dos命令窗口中完成该操作 tensorboard --logdir=~/log # 查看 localhost:6006 # 指定端口 tensorboard --logdir=~/log --port 8888
常见Tensorflow.summary API
- tf.summary.FileWriter.add_summary()
- 将摘要添加到事件文件
- tf.summary.FileWriter.add_event()
- 向事件文件添加一个事件
- tf.summary.FileWriter.add_graph()
- 向事件文件添加一个图
- tf.summary.FileWriter.get_logdir()
- 获取事件文件的路径
- tf.summary.FileWriter.flush()
- 将所有的时间都写入磁盘
- tf.summary.FileWriter.close()
- 将事件写入磁盘,并关闭文件操作符
- tf.summary.scalar()
- 输出包含单个标量值的摘要
- tf.summary.histogram()
- 输出包含直方图的摘要
- tf.summary.audio()
- 输出包含音频的摘要
- tf.summary.image()
- 输出包含图片的摘要
- tf.summary.merge()
- 合并摘要,包含所有输入摘要的值