利用tensorflow2实现手写数字识别

#首先介绍一下fit函数的参数列表中各个参数的意义
'''
fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose=1,
    validation_data=None,
    shuffle=True,
    validation_steps=None
)
fit的作用是在固定的epochs次数上训练模型,也就是在数据集上迭代

x 可以是numpy数组,tensorflow张量或者tensorflow.data类型的数据集

y y的类型与x要一致,x是numpy数组,y的类型也必须是numpy数组,x是tensorflow张量,y也必须是tensorflow的张量类型,y可以不提供的前提是x是
tensorflow.data类型的数据集,因为这种情况下标签会从数据集中得到

batch_size 整数类型或者是None,不要指明batch_size的大小当输入的数据是数据集或者生成器类型时(因为它们会生成自己的batch_size)
除非输入的数据是array或者tensors,那么可以指定batch_size,batch_size的意义就是每次梯度更新取样多少个数据(SGD等同于batch_size=1,GD等同于batch_size=总的数据的次数,最常用的mini-batch的梯度更新,也就是1

import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds

mnist_dataset,mnist_info=tfds.load(name='mnist',with_info=True,as_supervised=True)
print(mnist_info)
'''
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),

'''

assert len(mnist_dataset)==2 and type(mnist_dataset)==dict #mnist_dataset包含训练集和测试集合
mnist_train,mnist_test=mnist_dataset['train'],mnist_dataset['test']#train,test里面是包含有数据和标签两部分的
train_samples=mnist_info.splits['train'].num_examples
test_samples=mnist_info.splits['test'].num_examples
valid_samples=int(0.1*train_samples)#取训练集的十分之一做验证集
print(train_samples,test_samples,valid_samples)


def scaled_data(image,label):
    image=tf.cast(image,dtype=tf.float32)
    image/=255.#将数据集中所有数据值从0~255转到0~1之间
    return image,label

scaled_train_data=mnist_train.map(scaled_data)
scaled_test_data=mnist_test.map(scaled_data)

shuffle_size=10000
shuffle_train_data=scaled_train_data.shuffle(shuffle_size)

valid_data=shuffle_train_data.take(valid_samples)
train_data=shuffle_train_data.skip(valid_samples)

batch_size=100
batch_train_data=train_data.batch(batch_size)
valid_data=valid_data.batch(valid_samples)
test_data=scaled_test_data.batch(test_samples)
#注意到验证集和测试集由于不需要训练,所以不需要划分batch_size,直接将所有数据送进神经网络前向传播计算出loss以及accuracy即可,所以将batch_size的大小
#设置为数据的个数

input_dim=784
output_dim=10
hidden_dim1=50
hidden_dim2=64


def train_model1():
    model=tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28,28,1)),
        tf.keras.layers.Dense(hidden_dim1,activation='relu'),
        tf.keras.layers.Dense(hidden_dim2,activation='sigmoid'),
        tf.keras.layers.Dense(output_dim,activation='softmax')
    
        ])
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
    model.fit(batch_train_data,epochs=5,verbose=2,validation_data=valid_data,validation_steps=1)
    test_loss,test_accuracy=model.evaluate(test_data)
    print(test_loss,test_accuracy)
    print("*"*100)

def train_model2():
    model=tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28,28,1)),
        tf.keras.layers.Dense(hidden_dim1,activation='relu'),
        tf.keras.layers.Dense(hidden_dim2,activation='sigmoid'),
        tf.keras.layers.Dense(output_dim,activation='softmax')
    
        ])
    model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
    for epoch in range(5):
        for batch_data,batch_label in batch_train_data:
            batch_data=batch_data.numpy()
            batch_label=batch_label.numpy()
            assert batch_data.shape==(batch_size,28,28,1) and batch_label.shape==(batch_size,)
            model.fit(x=batch_data,y=batch_label,batch_size=batch_size,verbose=0)
            #如果不想每batch_size就输出当前的loss和accuracy,那就把verbose设置为0
            
        for batch_valid_data,batch_valid_label in valid_data:
            batch_valid_data=batch_valid_data.numpy()
            batch_valid_label=batch_valid_label.numpy()
            assert batch_valid_data.shape==(valid_samples,28,28,1) and batch_valid_label.shape==(valid_samples,)
            model.fit(x=batch_data,y=batch_label,batch_size=batch_size,verbose=2,epochs=1,validation_data=(batch_valid_data,batch_valid_label))
    #当传入的数据是numpy数组时,需要指明batch_size,不指定默认是32,在每次epoch的结束时候,用验证集来验证模型的效果
    print(model.evaluate(test_data))

train_model2()
train_model1()

你可能感兴趣的:(利用tensorflow2实现手写数字识别)