这一节将比较仔细的讲述一个比较重要的知识,就是模型的保存和加载。我们都知道深度学习模型在训练一些大数据集的时候往往需要很长的时间,如果这时候突然断网了或者停电了,那训练不就GG了吗?如果我想再次使用这个模型,难道要重新来过?当然不。在Tensorflow2.0中官方提供了多种模型保存和加载的方式,我们可以训练一定次数进行保存方便下次打开代码时接着进行训练,听起来是不是十分方便呢?接下来让我们来看一看具体的代码实现部分,我们采用最简单的MNIST数据集来演示模型保存和加载的过程。
1.导入相关库。
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
2.读取Mnist数据集。
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
3.选取部分数据集进行训练快速训练(这部分不是重点,因此简单配置一下)。在这里训练图像需要reshape成一维值,因为我们前期并没有经过卷积,而是直接传入MLP(多层感知机)进行训练。
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1,28*28)/255.0 #转一维并归一化
test_images = test_images[:1000].reshape(-1,28*28)/255.0 #转一维并归一化
4.搭建简单的模型。这里蛮提一下sparse_categorical_crossentropy损失函数和categorcial_crossentropy函数的差别,如果你的标签label是纯数字未经过one-hot编码则用前者,如果经过one-hot编码则用后者。
def creat_model():
model = tf.keras.models.Sequential([
layers.Dense(512,input_dim=784,activation='relu'),
layers.Dropout(0.2), #防过拟合
layers.Dense(10,activation='softmax')
])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
return model
model = creat_model()
5.打印模型概要。
model.summary()
1.以 checkpoints 形式保存,我们可以用tf.keras.callbacks.ModelCheckpoint这个函数允许模型在训练的过程中和训练结束时回调保存的模型。我们首先要定义模型文件要保存的路径,接着设置ModelCheckpoint一些相关参数,最后将将其作为回调传入模型进行训练。
checkpoint_path = 'training_1/cp.ckpt' #模型保存路径和名称
checkpoint_dir = os.path.dirname(checkpoint_path) #取模型保存文件的目录
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1) #只保存训练参数
model.fit(train_images,train_labels,epochs=10,validation_split=0.2,callbacks=[cp_callback])
2.重新建立一个未训练的模型,测试一下测试集的结果。
model = creat_model()
loss,acc = model.evaluate(test_images,test_labels,verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
可以看到模型未训练时,分类效果很差。
3.将刚才训练的模型参数加载至新建立的未训练模型中。这里有一点注意的是,如果我们仅仅只是加载模型的参数,那我们这个模型必须要和保存参数的那个模型结构一模一样才可以。
model.load_weights(checkpoint_path) #加载训练好的模型参数
loss,acc = model.evaluate(test_images,test_labels,verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
模型测试的结果正常,模型参数加载成功。
4.这里我们还可以在函数中加入period设置checkpoint的频数,规定它训练多少次才保存一次参数,这里设置为5代表每5次epoch保存一次checkpoint文件。
5.我们还可以用latest_checkpoint函数得到最新保存的那个模型参数文件,并加载看看测试结果。
latest = tf.train.latest_checkpoint(checkpoint_dir)
model = creat_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
从结果可以看出模型加载成功。
1.前面讲的方法是基于模型训练回调自动保存模型参数文件,自然的我们也可以通过手动的方式进行保存和加载,直接使用save_weights和load_weights函数即可。
model.save_weights('./checkpoints/my_checkpoint') #手动保存模型参数
model = creat_model()
model.load_weights('./checkpoints/my_checkpoint') #模型参数加载
loss,acc = model.evaluate(test_images,test_labels,verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
从结果可以看出模型参数加载成功。
1.我们还可以通过save和load_model直接保存及加载h5形式的模型文件(如果没有安装h5的库,可以参照前几节内容进行安装)。
model = creat_model()
model.fit(train_images,test_labels,epochs=20)
model.save('my_model.h5')
new_model = keras.models.load_model('my_model.h5')
以上就是这节内容,知道了怎么保存和加载我们的模型之后,之后我们便可以更方便地训练模型了。谢谢大家的观看和支持!