torch.save()
将序列化对象保存到磁盘。序列化就是把数据变成可存储或可传输的过程的,只有序列化后的数据才可以写入到磁盘或者通过网。此函数使用Python的 pickle
进行序列化。使用此函数可以保存各种对象的模型、张量和字典。
pickle
的一些具体使用函数
pickle.dump() #将任意对象转化成bytes,并写入文件中。
import pickle
# 使用pickle模块将数据对象保存到文件
data = {'a': [1, 2.0, 3, 4+6j],
'b': ('string', u'Unicode string'),
'c': None}
file = open('data.pkl','wb') #将任意对象转化成bytes,并写入文件中
pickle.dump(data, file)
file.close()
torch.load()
使用Python的pickle
将被序列化的对象文件反序列化到内存。此函数还可方便设备将数据加载进来(请看 Saving & Loading Model Across Devices).
pickle
的一些具体使用函数
pickle.load() #从文件中反序列出对象
import pickle
file = open('data.pkl','rb')
data = pickle.load(file)
file.close()
print(data)
#{'a': [1, 2.0, 3, (4+6j)], 'b': ('string', 'Unicode string'), 'c': None}
在PyTorch中,
model.parameters()
访问torch.nn.Module
模型的可学习参数(即权重和偏置)包含在模型的 parameters 中。torch.nn.Module
的state_dict
是一个Python字典对象,它将具有可学习参数的层(卷积层、线性层等)映射到其参数张量。torch.optim
有一个state_dict
,包含关于优化器状态的信息以及使用的超参数。torch.nn.Module.load_state_dict()
使用反序列化的state_dict
加载模型的参数字典。import torch.nn as nn
import torch.functional as F
import torch.optim as optim
# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化 model
model = TheModelClass()
# 初始化 optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 输出 model 的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Model's state_dict:
# conv1.weight torch.Size([6, 3, 5, 5])
# conv1.bias torch.Size([6])
# conv2.weight torch.Size([16, 6, 5, 5])
# conv2.bias torch.Size([16])
# fc1.weight torch.Size([120, 400])
# fc1.bias torch.Size([120])
# fc2.weight torch.Size([84, 120])
# fc2.bias torch.Size([84])
# fc3.weight torch.Size([10, 84])
# fc3.bias torch.Size([10])
# 输出 optimizer 的 state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
#Optimizer's state_dict:
# state {}
# param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
在保存模型进行推理时,只需保存经过训练的模型的学习参数state_dict
即可。使用 torch.save()
函数保存模型的state_dict
将为以后恢复模型提供最大的灵活性,这就是为什么推荐使用它来保存模型。
常见的PyTorch约定是使用 .pt
或.pth
文件扩展名保存模型。
模型参数的保存
torch.save(model.state_dict(), file_path)
模型参数的加载
#device = torch.device("cuda")
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(file_path))
#model.to(device)
PyTorch save and load models 保存和加载模型