PyTorch 加载模型

一、假设我们只保存了模型的参数（model.state_dict()）到文件 modelparameters.pth 中，并已创建模型实例 model = Net()：

1. cpu -> cpu或者gpu -> gpu:

# Without a remap, tensors come back on the device they were saved from (CPU stays CPU, GPU stays GPU).
checkpoint = torch.load('modelparameters.pth', map_location=None)

model.load_state_dict(checkpoint)

2. cpu -> gpu 0

# Restore every stored tensor onto GPU 0, regardless of where it was saved.
checkpoint = torch.load('modelparameters.pth', map_location='cuda:0')

model.load_state_dict(checkpoint)

3. gpu 0 -> gpu 1

# A dict map_location maps the *saved* location to the *target* location,
# so tensors saved on cuda:0 are restored onto cuda:1.
checkpoint = torch.load('modelparameters.pth', map_location={'cuda:0': 'cuda:1'})

model.load_state_dict(checkpoint)
4. gpu -> cpu
# Keep every tensor on the CPU, even when the checkpoint was saved from a GPU.
checkpoint = torch.load('modelparameters.pth', map_location=torch.device('cpu'))

model.load_state_dict(checkpoint)
二、PyTorch 完整加载模型实例
# 模型是GPU训练的
def load_models(model_param_path, seed=14):
    """Build a FeatherNetB model, load trained weights into it and return it in eval mode.

    The model is wrapped in ``torch.nn.DataParallel`` before the weights are
    loaded, so the checkpoint's parameter names must match that wrapper
    (presumably the checkpoint was saved from a DataParallel model -- confirm).

    Args:
        model_param_path: Path to a checkpoint file readable by ``torch.load``.
            The loaded object must contain a ``'state_dict'`` entry.
        seed: Seed applied to ``torch``, ``numpy`` and ``random`` (and to all
            CUDA devices when CUDA is available). Defaults to 14.

    Returns:
        The DataParallel-wrapped model, moved to ``cuda:0`` (or to the CPU when
        CUDA is not available) and switched to evaluation mode.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(seed)

    model = FeatherNetB()
    model = torch.nn.DataParallel(model)

    device = torch.device('cuda:0' if use_cuda else 'cpu')
    model.to(device)

    print("=> loading checkpoint '{}'".format(model_param_path))
    # Remap every stored tensor directly onto the target device. This single
    # call covers GPU testing and CPU-only testing. It also works when the
    # checkpoint was saved from a GPU index that does not exist on this machine.
    checkpoint = torch.load(model_param_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # model.eval() is required for testing/inference; use model.train() for training.
    model.eval()

    return model

 

你可能感兴趣的:(Pytorch加载模型)