基于pytorch框架使用多gpu训练时,如何有效降低显存

基于pytorch框架使用多gpu训练时,如何有效降低显存

文章目录

  • 基于pytorch框架使用多gpu训练时,如何有效降低显存
    • 1. 问题阐述
    • 2. 解决方案
      • 2.1 周期性使用`torch.cuda.empty_cache()`函数
      • 2.2 将loss的计算写入网络的`forward()`函数中
      • 2.3 采用`del()`函数及时删除创建的中间变量或无用的临时变量
      • 2.4 使用relu激活函数时,记得将inplace参数设为True
      • 2.5 采用累积梯度更新的方式
    • 3. 总结

1. 问题阐述

​ 我们在基于pytorch框架采用多gpu进行大数据的训练时,由于自己的batch_size设计的偏大或者其他原因,经常会在程序运行过程中遇到RuntimeError,使得程序中断。在RuntimeError中,Out of Memory``错误出现的最多,其次是 CUDNN_STATUS_EXECUTION_FAILED错误。这两种错误发生的原因都是因为显存不够。

2. 解决方案

为了优化显存的使用,在网上查阅的资料和自己亲身实践的基础上,总结出以下几点

2.1 周期性使用torch.cuda.empty_cache()函数

  1. 使用pytorch在进行神经网络训练时,会产生很多临时变量,随着训练的进行,临时变量将会对训练效率产生一点影响,所以定期采用empty_cache()不仅可以及时清除训练过程中的缓存,释放更多的显存空间,而且提高训练效率。
  2. model.eval()之前使用empty_cache(),可以及时清理训练阶段的临时数据
  3. 在保存网络模型前使用empty_cache(),可以减少保存的模型大小。

总之,empty_cache()函数是个非常有用的函数,对于优化显存使用有着很大的帮助

2.2 将loss的计算写入网络的forward()函数中

​ 这里要先说明一下pytorch框架所采用的多gpu并行计算的原理。pytoch会先根据你的设定选用第一块gpu为__主gpu来存放模型__,而其他的gpu只用来接受主gpu分发的数据进行计算。并且__并行计算只发生在神经网络的前馈计算__中,反向梯度传播是在主gpu上进行的。所以,在采用多gpu训练大规模网络的时候,经常会出现负载不均衡的现象。

​ 为了缓解该现象,我们可以将loss的计算写入forward()函数中。因为loss的计算是比较消耗显存的。从博主的经验来看,loss的计算对于显存的消耗不大,但是建议将loss的计算写入forward()函数。

2.3 采用del()函数及时删除创建的中间变量或无用的临时变量

​ 为了减少中间变量或无用的临时变量对于缓存的占用,采用del()函数对其进行删除是种不错的选择。

2.4 使用relu激活函数时,记得将inplace参数设为True

output = f.relu(x, inplace=True)

​ 将inplace参数设为True,将不会产生中间变量,从而减少显存的使用,但是从博主的经验来看,影响不大,但建议最好加上

2.5 采用累积梯度更新的方式

​ 由于显存的限制,在训练大规模网络模型的时候,batch_size的一般设的比较小,但是batch_size过小会影响网络的收敛速度,所以可以采用累积梯度更新的方式,实现大batch_size一样的效果。

​ 累积梯度更新的原理,计算得到多个batch的梯度值后对网络进行更新。

​ 累积梯度更新的实现:

loss = model(batch)
loss = loss / gradient_accumulation_steps
loss.backward()
if step % gradient_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

3. 总结

​ 文中提到的2.1~2.4都是不错的优化显存使用的方法,2.5提到的累计梯度更新的方法对于训练大规模网络有比较好的作用。

你可能感兴趣的:(pytorch)