[Python error] PyTorch 1.9 ImportError: cannot import name 'zero_gradients'

Error: PyTorch 1.9 ImportError: cannot import name 'zero_gradients'


Error message:

ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck' (/root/miniconda3/envs/d2l/lib/python3.9/site-packages/torch/autograd/gradcheck.py)

Cause:

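After upgrading PyTorch, the function no longer exists: PyTorch 1.9 does not provide zero_gradients in torch.autograd.gradcheck, but advertorch still tries to import it from there.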
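To confirm that the installed PyTorch version is the cause, a quick check is to attempt the same import that advertorch performs. This is a minimal sketch; the helper name has_zero_gradients is hypothetical and not part of PyTorch or advertorch.

```python
import torch


def has_zero_gradients() -> bool:
    """Hypothetical helper: True if torch.autograd.gradcheck still exposes zero_gradients."""
    try:
        # The exact import that fails inside advertorch on PyTorch 1.9
        from torch.autograd.gradcheck import zero_gradients  # noqa: F401
    except ImportError:
        return False
    return True


print("PyTorch version:", torch.__version__)
print("zero_gradients available:", has_zero_gradients())
```

If the second line prints False, the installed PyTorch is affected and the fix below applies.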

Fix:

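In ~/miniconda3/envs/d2l/lib/python3.9/site-packages/advertorch/attacks/fast_adaptive_boundary.py (this path comes from a conda environment named d2l; adjust it to your own environment), delete the line
from torch.autograd.gradcheck import zero_gradients
and add the following definition in its place: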

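import collections.abc  # needed for the Iterable check below
import torch  # harmless if the file already imports torch

def zero_gradients(x):
    """Local replacement for the helper no longer available in torch.autograd.gradcheck."""
    if isinstance(x, torch.Tensor):
        # Detach the gradient from any graph, then reset it to zero in place
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        # Recurse into lists/tuples of tensors
        for elem in x:
            zero_gradients(elem)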
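If the same code has to run on both older and newer PyTorch releases, one option (an assumption on my part, not the fix from the reference below) is to keep the original import and fall back to the local definition only when it fails. A minimal sketch, with a small usage check:

```python
import collections.abc

import torch

try:
    # Older PyTorch releases still provide this helper
    from torch.autograd.gradcheck import zero_gradients
except ImportError:
    # PyTorch 1.9 and later: use an equivalent local definition
    def zero_gradients(x):
        """Detach and zero .grad of a tensor, or of every tensor in a (nested) iterable."""
        if isinstance(x, torch.Tensor):
            if x.grad is not None:
                x.grad.detach_()
                x.grad.zero_()
        elif isinstance(x, collections.abc.Iterable):
            for elem in x:
                zero_gradients(elem)


# Usage: accumulate a gradient, then clear it
w = torch.ones(3, requires_grad=True)
(w * 2).sum().backward()
print(w.grad)        # tensor([2., 2., 2.])
zero_gradients([w])  # accepts a single tensor or an iterable of tensors
print(w.grad)        # tensor([0., 0., 0.])
```

The try/except keeps the file unchanged on older versions and only activates the fallback where the import is missing.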

Reference: https://zhuanlan.zhihu.com/p/420312739
