Pytorch使用DDP加载预训练权重时出现占用显存的多余进程

感谢知乎作者 https://www.zhihu.com/question/67209417/answer/866488638
在使用DDP进行单机多卡分布式训练时,出现了在加载预训练权重时显存不够的现象,但是相同的代码单机单卡运行并不会出现问题,后来发现是在多卡训练时,额外出现了3个进程同时占用了0卡的部分显存导致的,而这3个进程正是另外3张卡load进来的数据,默认这些数据被放在了0卡上。解决的方法是把load进来的数据放在cpu(也就是内存)里。

# 原来代码,load进的数据放在gpu里
# pretrain_weight = torch.load(path)['model']
# 应该改成
pretrain_weight = torch.load(path, map_location=torch.device('cpu'))['model']
model.load_state_dict(pretrain_weight)

你可能感兴趣的:(环境配置,pytorch,python,深度学习)