Pytorch使用loss求和踩坑记录

Pytorch使用loss.item()踩坑记录

问题描述

使用Pytorch框架进行模型性能测试时,免不了需要使用sum_loss += loss对模型每个一个迭代的损失进行求和,最终计算整个测试集上的平均损失。这时我发现我的代码跑着跑着…显存就崩了。不对,压根就还没跑呢。每次一到测试的阶段代码就会报错说我的显存溢出。我都把batch_size改为1了还是会出现这种问题,问题是之前还是好好地,我只是把模型评价指标的求解过程改动了一下。

问题解决

为啥会出现显卡显存溢出的情况呢?原来loss属于Variable类型,当我们使用加和操作时,Pytorch框架会默认我们的计算图构建还没有结束,还在继续构建计算图,因此我们每加和一次loss,就会在原来模型的基础上构建一个新的模型,因此迭代不了几次,显卡内存就会爆满,出现显存溢出的情况。解决这种情况很简单,只需要在加和loss的时候使用loss.item()就行了。这样loss叠加只是在数值上进行叠加了,而不是在loss这个Varuable上进行堆叠了,因此就不会出现显存爆炸的情况了。

疑问

那么还有个问题,为什么在模型训练阶段使用loss相加不会出现叠加计算图的情况呢?有兴趣的小伙伴可以自己去查一下!

你可能感兴趣的:(神经网络,深度学习,pytorch)