以前用Keras用惯了,fit和fit_generator真的太好使了,模型断电保存搞个checkpoint回调函数就行了。近期使用pytorch进行训练,苦于没有类似的回调函数,写完网络进行训练的时候总不能每次都从头开始训练,于是乎就学了一下pytorch的模型相关操作。
argparse是一个Python模块:命令行选项、参数和子命令解析器。
主要有三个步骤:
如下:
parser = argparse.ArgumentParser()
parser.add_argument('--train-file', type=str, default='pre/91-image_x2.h5')
parser.add_argument('--eval-file', type=str, default='pre/Set5_x2.h5')
parser.add_argument('--outputs-dir', type=str, default='output/')
parser.add_argument('--weights-file', type=str, default='weight/')
parser.add_argument('--scale', type=int, default=2)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--batch-size', type=int, default=1024)
parser.add_argument('--num-epochs', type=int, default=1000)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--seed', type=int, default=123)
args = parser.parse_args()
使用参数的时候可以通过args.xxx进行调用就行了,这样的好处就是方便统一管理。 如果用编译器运行就得给赋默认值,如果命令行运行,在运行的时候命令后面给出参数就行,有默认值的会进行覆盖。
包含网络模型的实例化,损失函数,优化器等一系列操作。
if not os.path.exists(args.outputs_dir):
os.makedirs(args.outputs_dir)
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)
model = FSRCNN(scale_factor=args.scale).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam([
{'params': model.first_part.parameters()},
{'params': model.mid_part.parameters()},
{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
], lr=args.lr)
加载上次训练生成的参数文件,通过update操作进行更新,并加载到现有模型中进行训练,这个也就是预训练参数,还要去掉参数中多余的k,v对。
model_dict = model.state_dict()
pre_dict = torch.load('训练参数文件.pth')
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)
torch.save就可以进行模型的保存,里面传入的参数不一样保存方式就不一样。
torch.save(model.state_dict(), path)
eg:
torch.save(model.state_dict(),
os.path.join(args.outputs_dir, 'epoch_{}_psnr{:.2f}.pth'.format(best_epoch, best_psnr)))
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
这种方式加载的时候:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])
torch.save(model, path)
这种类型的错误是因为版本原因导致的,在pytorch1.6及以上,pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。解决办法要么把pytorch升级到1.6及以上版本,要么就在保存的时候不适用zip格式进行保存。
state_dict = torch.load("xxx.pth")
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)
报这个错就是因为没有去除掉参数文件中多余的键值对,去掉就行。
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
pytorch1.6以下版本还不报错,1.6以上就报错了,原因是在加载的时候,没有GPU,只有CPU,必须在load中指明使用CPU才可以。
model = torch.load(model_path, map_location='cpu')