在修改网络结构之后,想用hidden layer库
绘图观察一下网络结构是否修改正确
结果报错 RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
将 __init__
中定义模块使用的list
转换为ModuleList
from batchgenerators.utilities.file_and_folder_operations import join
import hiddenlayer as hl
g = hl.build_graph(model, torch.rand((1, 3, 48, 192, 192)), transforms=None)
arch_folder = "/home/anaconda3/envs/py3.7_torch1.8.1/lib/python3.7/site-packages/nnunet/network_architecture"
g.save(join(arch_folder, "Res_UNet_architecture.pdf"))
print(model)
和 y = model(x)
都没有报错 File "/home/anaconda3/envs/py3.7_torch1.8.1/lib/python3.7/site- packages/nnunet/network_architecture/generic_UNet.py", line 294, in forward
x = self.blocks[0](x)
File "/home/anaconda3/envs/py3.7_torch1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 887, in _call_impl`
网络里的这个模块定义如下,代码过长,仅保留有关的部分
class StackedConvLayersWithRes(nn.Module):
def __init__(self, input_feature_channels, output_feature_channels, num_convs,
conv_op=nn.Conv2d, conv_kwargs=None,
norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin):
'''
代码过长,仅保留有关的部分 此处省略很多定义....
关键在下面三句
'''
self.blocks = []
self.blocks.append( basic_block(input_feature_channels, output_feature_channels, self.conv_op,
self.conv_kwargs_first_conv,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
self.nonlin, self.nonlin_kwargs))
for _ in range(num_convs - 1):
self.blocks.append( basic_block(output_feature_channels, output_feature_channels, self.conv_op,
self.conv_kwargs,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
self.nonlin, self.nonlin_kwargs) )
def forward(self, x):
# 如果要加res的话 要改在这里
x = self.blocks[0](x) # 第一个con会做下采样 channel数目不一致
skip = x
# 取这里的feature map做Res
for i in range(1,len(self.blocks)):
x = self.blocks[i](x)
return skip + x
问题出在这里,self.blocks使用的是list
只需要在l构建好之后,在后面加一句
self.blocks = nn.ModuleList(self.blocks)
即可,使用的时候可以一样使用下标访问ModuleList中的内容,所以forword代码也不用改。
运行之后发现不会报错了!