PyTorch学习——断点续训
1、保存断点模型状态
# Save a resumable checkpoint once per epoch: the epoch counter plus the
# model and optimizer state dicts are enough to continue training later.
for epoch in range(starting_epoch, parse.epochs):
    for step, (x, y) in enumerate(train_loader):
        ...
    if epoch % 1 == 0:  # save every epoch; raise the modulus to save less often
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        # Zero-padded name, e.g. ./model/model_007.pt
        model_path = "./model/" + "model_{:03d}.pt".format(epoch)
        torch.save(checkpoint, model_path)
2、读取状态
# Restore training state from a saved checkpoint file selected by the
# --checkpoint argument, then resume at the epoch AFTER the one saved.
checkpoint_path = "./model/model_" + parse.checkpoint + ".pt"
train_state = torch.load(checkpoint_path)
model.load_state_dict(train_state['model_state_dict'])
optimizer.load_state_dict(train_state['optimizer_state_dict'])
starting_epoch = train_state['epoch'] + 1  # the saved epoch already finished
3、完整的例程
import argparse
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch import optim, nn
from torch.utils.data import DataLoader

from defect import DefectDataset
from resnet import ResNet18
def parse_args():
    """Parse command-line arguments for the defect-detection trainer.

    Returns:
        argparse.Namespace with fields:
            echo (str): positional string echoed back at startup.
            epochs (int): number of training epochs (default 10).
            resume (bool): resume training from a saved checkpoint.
            checkpoint (str): checkpoint number used to build the file name.
    """
    # "defeat detect" was a typo for "defect detect".
    parser = argparse.ArgumentParser(description="defect detect")
    parser.add_argument("echo", help="echo the string you use here")
    parser.add_argument("--epochs", default=10, help="train epochs", type=int)
    # Fixed copy-pasted help text: this flag resumes training from a
    # checkpoint; it has nothing to do with output verbosity.
    parser.add_argument("--resume", help="resume training from a checkpoint",
                        action="store_true")
    parser.add_argument("--checkpoint", default="0", help="checkpoint num")
    return parser.parse_args()
def evalute(model, loader, device):
    """Return the classification accuracy of `model` over `loader`.

    NOTE: the name keeps the original spelling ("evalute") because callers
    in this file depend on it.

    Args:
        model: network producing class logits of shape (batch, num_classes).
        loader: DataLoader; `len(loader.dataset)` gives the sample count.
        device: device to run inference on.

    Returns:
        float accuracy in [0, 1]; 0.0 for an empty dataset.
    """
    model.eval()
    correct = 0
    total = len(loader.dataset)
    if total == 0:  # avoid ZeroDivisionError on an empty dataset
        return 0.0
    # Hoist no_grad around the whole loop so the device transfers and the
    # forward pass all run without autograd bookkeeping.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total
def train(parse, device, train_loader, val_loader, train_state=None, lr=1e-3):
    """Train a ResNet18 classifier with per-epoch validation and checkpointing.

    Args:
        parse: argparse.Namespace with `.epochs`, `.resume` and `.checkpoint`.
        device: device to train on.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        train_state: optional pre-loaded checkpoint dict with keys
            'epoch', 'model_state_dict', 'optimizer_state_dict'.
            Overridden when `parse.resume` is set.
        lr: Adam learning rate. Previously this was read from an undefined
            global (NameError when importing this module); it is now an
            explicit keyword parameter with the same default value.
    """
    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    train_loss_list = []
    acc_val_list = []
    starting_epoch = 0
    if parse.resume:
        # --resume takes precedence over any train_state passed in.
        checkpoint_path = "./model/model_" + parse.checkpoint + ".pt"
        train_state = torch.load(checkpoint_path)
    if train_state is not None:
        model.load_state_dict(train_state['model_state_dict'])
        optimizer.load_state_dict(train_state['optimizer_state_dict'])
        # The saved epoch already finished, so resume at the next one.
        starting_epoch = train_state['epoch'] + 1
    # torch.save does not create missing directories; ensure it exists.
    os.makedirs("./model", exist_ok=True)
    for epoch in range(starting_epoch, parse.epochs):
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            model.train()  # evalute() switched the model to eval mode
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            train_loss_list.append(loss.item())
            print('epoch: ', epoch, 'step:', step, 'loss: ', loss.item())
        if epoch % 1 == 0:  # validate/checkpoint every epoch
            val_acc = evalute(model, val_loader, device)
            acc_val_list.append(val_acc)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            model_path = "./model/" + "model_{:03d}.pt".format(epoch)
            torch.save(checkpoint, model_path)
    print('best acc:', best_acc, 'best epoch:', best_epoch)
if __name__ == '__main__':
    parse = parse_args()
    print(parse.echo)
    epochs = parse.epochs
    batchsz = 8
    lr = 1e-3
    device = torch.device('cuda:0')
    torch.manual_seed(1234)  # reproducible weight init / shuffling
    train_db = DefectDataset('data', 416, mode='train')
    val_db = DefectDataset('data', 416, mode='val')
    test_db = DefectDataset('data', 416, mode='test')
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
    # BUG FIX: the val/test loaders previously wrapped train_db, so
    # "validation accuracy" was silently measured on the training set.
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)
    train(parse, device, train_loader, val_loader)