RuntimeError: Assertion cur_target 0 cur_target n_classes failed

问题描述

使用pytorch的函数 torch.nn.CrossEntropyLoss()计算Loss时报错:

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

报错原因

直观上看,函数要求目标分类数大于等于0并且小于等于输入的类别。所以一般而言,都是网络中输出的种类数和标签中设置的种类数量不同造成的。

解决方案

针对于不同原因,主要从两方面考虑解决。

方向一:模型输出与分类数不一致

  1. 看一下模型的输出尺寸与分类数差异是否明显,核查代码是否存在错误。
  2. 如果没有错误,只是映射维度不对,可以考虑在模型的最后一层加一层FC层,将输出尺寸映射到分类大小。

方向二:标签的设置不是从0开始

  1. 如果模型的输出尺寸与分类数大小相同,看一下标签的设定是否是从0开始的。
  2. 如果标签是从1开始设置的,重新设置标签。这里存在的坑是:在使用CrossEntropyLoss()这个函数进行验证时,标签必须从0开始设置,否则便会报错。

你可能感兴趣的:(Ubuntu,Python)