tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战

有了一些tensorflow2.0的基础,然后加上之前对tensorflow也有了解,所以今天用tensorflow2.0做个mnist的分类实战,这也相当于“hello world”吧。

首先看看最后的训练结果图,网络结构图

tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第1张图片
tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第2张图片
可以看到网络结构其实很简单,就是输入层,然后两层中间层,最后一个输出层用softmax

代码部分

import tensorflow as tf
import numpy as np
import math
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential
from tensorflow.keras.callbacks import TensorBoard
import datetime
import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus']=False


#导入数据
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
print(x_train.shape,x_test.shape)

在这里插入图片描述

#预处理---正规化
def normalize(x,y):
    x=tf.cast(x,tf.float32)
    x/=255
    return x,y

#添加一层维度,方便后续扁平化
x_train=tf.expand_dims(x_train,axis=-1)
x_test=tf.expand_dims(x_test,axis=-1)

train_dataset=tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_dataset=tf.data.Dataset.from_tensor_slices((x_test,y_test))
train_dataset=train_dataset.map(normalize)
test_dataset=test_dataset.map(normalize)


#画图
plt.figure(figsize=(10,15))
i=0
for (x_test,y_test) in test_dataset.take(25):
    x_test=x_test.numpy().reshape((28,28))
    plt.subplot(5,5,i+1)
    plt.grid(False)
    plt.xticks([])
    plt.imshow(x_test,cmap=plt.cm.binary)
    plt.xlabel([y_test.numpy()],fontsize=10)
    i+=1
plt.show()

tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第3张图片

#开始定义模型
model=tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28,1)),
    tf.keras.layers.Dense(128,activation=tf.nn.relu),
    tf.keras.layers.Dense(64,activation=tf.nn.relu),
    tf.keras.layers.Dense(10,activation=tf.nn.softmax)
])

model.summary()

tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第4张图片

# 模型编译
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

#开始训练
batch_size=32
train_dataset=train_dataset.repeat().shuffle(60000).batch(batch_size)
test_dataset=test_dataset.batch(batch_size)
#为tensorboard可视化保存数据
tensorboard_callback=tf.keras.callbacks.TensorBoard(histogram_freq=1)
model.fit(train_dataset,epochs=5,steps_per_epoch=math.ceil(60000/batch_size),callbacks=[tensorboard_callback])

tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第5张图片
这里本来想自己定义保存文件的路径的,但是运行就出错,没找到修改的办法只能删除自己的路径,还是用默认的路径,希望大佬们能指导一下这里该怎么修改

#模型评估
test_loss,test_accuracy=model.evaluate(test_dataset,steps=math.ceil(10000/32))
print('Accuracy on test_dataset',test_accuracy)

在这里插入图片描述

#模型预测
predictions=model.predict(test_dataset)

#查看预测结果
def plot_test(i,predictions_array,true_labels,images):
    predic,label,img=predictions_array[i],true_labels[i],images[i]
    plt.grid(False)
    plt.xticks([])
    plt.imshow(img[...,0],cmap=plt.cm.binary)
    predic_label=np.argmax(predic)
    if(predic_label==label):
        color='green'
    else:
        color='red'
    plt.xlabel("预测标签为:{},概率:{:2.0f}% (真实标签:{})".format(predic_label,100*np.max(predic),label),color=color)
    
    
def plot_value(i,predictions_array,true_label):
    predic,label=predictions_array[i],true_label[i]
    plt.grid(False)
    plt.xticks([])
    thisplot=plt.bar(range(10),predic,color='#777777')
    plt.ylim([0,1])
    predic_label=np.argmax(predic)
    thisplot[predic_label].set_color('blue')
    thisplot[label].set_color('green')
    
    
rows,cols=5,3
num_images=rows*cols
for test_images,test_labels in test_dataset.take(1):
    test_images=test_images.numpy()
    test_labels=test_labels.numpy()
    
plt.figure(figsize=(2*2*cols,2*rows))
for i in range(num_images):
    plt.subplot(rows,2*cols,2*i+1)
    plot_test(i,predictions,test_labels,test_images)
    plt.subplot(rows,2*cols,2*i+2)
    plot_value(i,predictions,test_labels)

tensorflow2.0---2.利用tensorflow2.0进行mnist分类实战_第6张图片

总结:

总的来说,tensorflow2.0的使用还是比较简单的,没有了之前1.X的session感觉更方便,如果对1.X有的api有不懂的地方可以随时去https://www.w3cschool.cn/tensorflow_python/查找需要的帮助;
对于2.0也可以去https://tf.wiki学习。

你可能感兴趣的:(tensorflow2.0)