更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题
阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。
U-Net V2 是一种先进的医学图像分割模型,它通过改进的跳跃连接和语义细节注入(SDI)模块,有效地融合了高级语义信息和低级细节信息,从而显著提升了分割精度。相比原始 U-Net,U-Net V2 在多个数据集上表现出更高的性能,同时保持了较低的内存占用和计算复杂度,使其在医学图像分割领域具有广泛的应用前景。其网络结构如下图。
点击链接,将其clone到本地后,进入文件夹内,pip install -e . 即可(注意-e后有个点)。
运行nnUNetv2_plan_and_preprocess命令,也是预处理命令,生成nnUNetPlans.json文件。
由于读者预期替换的网络不一定是本文替换的网络,且所用数据集不一定是本文所用数据集,所以
本文主要介绍如何修改网络,使修改后的nnU-Net V2可以正常训练,给读者提供实践样例,不会涉及评估指标。
在对相同网络结构进行修改时,读者往往会有不同的修改思路 。在这一过程中,请读者以成功运行代码、是否便于读者后续回顾作为考量标准,确定修改思路。
U-Net V2 网络结构包括编码器、语义和细节注入(SDI)模块以及解码器三部分。
编码器用于提取多级特征,与nnU-Net V2默认一致。
SDI模块对编码器(下采样)产生的多层特征图应用空间注意力机制和通道注意力机制,再通过哈达马乘积将高级特征的语义信息与低级特征的细节信息融合。且U-Net V2在下采样结束后才对特征图添加注意力机制、融合多层特征图。
解码器利用融合后的多级特征图进行上采样,nnU-Net V2默认解码器的各层通道数与编码器相同,在U-Net V2中,解码器各层通道数则是相同值。
打开nnUNet \ DATASET \ nnUNet_preprocessed \ Dataset001_ACDC \ nnUNetPlans.json文件,查看configurations --> 2d --> architecture --> network_class_name字段,默认为dynamic_network_architectures.architectures.unet.PlainConvUNet。
根据network_class_name字段,找到PlainConvUNet类所在文件:dynamic-network-architectures-main \ dynamic_network_architectures \ architectures \ unet.py
PlainConvUNet类就是nnU-Net默认的U-Net,其结构由编码器和解码器两部分组成,很标准,很常见。具体代码、结构见PlainConvUNet类
nnU-Net V2 采用的 U-Net ,与 U-Net V2 的主要区别在于缺少 SDI 模块。因此,我们不修改network_class_name字段,不使用完全自定义的网络结构,而是对 PlainConvUNet 类进行“改造”,加入 SDI 模块。
我们将修改过程分为三步:
本次替换的一些设置
训练配置 | 2d |
更换的网络 | U-Net V2 |
涉及的文件(加粗文件是要修改的文件):
dynamic_network_architectures/
|__ architectures/
| |__ resnet.py
| |__ unet.py
| |__ vgg.py
|__ building_blocks/
| |__ helper.py
| |__ plain_conv_encoder.py
| |__ regularization.py
| |__ residual.py
| |__ residual_encoders.py
| |__ simple_conv_blocks.py
| |__ unet_decoder.py
| |__ unet_residual_decoder.py
|__ initialization/
| |__ weight_init.py
由于有三步,代码一起放上来会乱,所以只会放当前步骤涉及的代码
在PlainConvUNet类的forward函数中,self.encoder和self.decoder分别是U-Net的编码器和解码器,skips是一个list,存储下采样产生的特征图:
def forward(self, x):
skips = self.encoder(x)
return self.decoder(skips)
遍历skips,对每一个元素(特征图)应用注意力机制。
空间注意力机制和通道注意力机制代码很常见,搜索或者AI都容易,在unet.py文件中加入这两个类:
# 空间注意力机制
class SpatialAttention(nn.Module):
def __init__(self, channels):
super(SpatialAttention, self).__init__()
self.sigm = nn.Sigmoid()
self.conv2d = nn.Conv2d(2, 1, kernel_size=1, bias=False)
self.channels = channels
def forward(self, x):
x1 = torch.mean(x, dim=1, keepdim=True)
x2, _ = torch.max(x, dim=1, keepdim=True)
out = torch.cat((x1, x2), dim=1)
out = self.conv2d(out)
return self.sigm(out) * x
def compute_conv_feature_map_size(self, input_size):
output = 2 * np.prod(input_size, dtype=np.int64) # mean和max的特征图
output += output # cat后
output += 1 * np.prod(input_size, dtype=np.int64) # 卷积输出特征图
output += np.prod([self.channels, *input_size], dtype=np.int64) # return
return output
# 通道注意力机制
class ChanAttention(nn.Module):
def __init__(self, channel):
super(ChanAttention, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.maxpool = nn.AdaptiveMaxPool2d(1)
self.linear = nn.Sequential(
nn.Linear(channel, channel // 16, bias=False),
nn.LeakyReLU(inplace=True),
nn.Linear(channel // 16, channel, bias=False),
nn.Sigmoid()
)
self.channel = channel
def forward(self, x):
b, c, _, _ = x.size()
_avg = self.linear(self.avgpool(x).view(b, c))
_max = self.linear(self.maxpool(x).view(b, c))
out = (_avg + _max).view(b, c, 1, 1)
return x * out.expand_as(x)
def compute_conv_feature_map_size(self, input_size):
output = 2 * np.prod([self.channel, *input_size], dtype=np.int64) # return部分
output += 2 * self.channel * (self.channel // 16) # 两个linear层
output += 2 * self.channel # avgpool层 + maxpool层
return output
compute_conv_feature_map_size函数用于预估显存使用量(在PlainConvUNet类中介绍过)。
接下来需要在PlainConvUNet类的__init__函数中实例化这两个类,下面是要直接加入的代码:
# 用于计算SDI模块占用的显存量
self.features_per_stage = features_per_stage
# 注意力机制
channel_attentions = []
spatial_attentions = []
for i in range(len(features_per_stage)):
# 注意力
channel_attentions.append(ChanAttention(features_per_stage[i]))
spatial_attentions.append(SpatialAttention(features_per_stage[i]))
# 注意力
self.chan_atten = nn.Sequential(*channel_attentions)
self.spa_atten = nn.Sequential(*spatial_attentions)
在PlainConvUNet类的forward函数中加入注意力机制,下面是要替换的代码:
def forward(self, x):
skips = self.encoder(x)
for i in range(len(skips)):
# 注意力机制
skips[i] = self.chan_atten[i](skips[i])
skips[i] = self.spa_atten[i](skips[i])
return self.decoder(skips)
在PlainConvUNet类的compute_conv_feature_map_size函数中加入两个类,下面是要替换的代码:
def compute_conv_feature_map_size(self, input_size):
assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
"batch channel. Do not give input_size=(b, c, x, y(, z)). " \
"Give input_size=(x, y(, z))!"
output = self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
for i in range(len(self.features_per_stage)):
# 通道和空间注意力
output += self.chan_atten[i].compute_conv_feature_map_size(input_size)
output += self.spa_atten[i].compute_conv_feature_map_size(input_size)
input_size = [j // 2 for j in input_size]
return output
这一步结束后,读者可以先训练nnU-Net V2,看看加入注意力机制是否提高了评估指标
由于各级特征图通道数、尺寸均不相同,所以先将通道数转换为相同值,再修改尺寸,进行融合。
在unet.py文件中直接加入SDI类:
class SDI(nn.Module):
def __init__(self, same_channel, len):
super().__init__()
self.same_channel = same_channel
self.len = len
self.convs = nn.ModuleList(
[nn.Conv2d(same_channel, same_channel, kernel_size=3, stride=1, padding=1) for _ in range(len)])
def forward(self, xs, anchor):
ans = torch.ones_like(anchor)
target_size = anchor.shape[-2:]
for i, x in enumerate(xs):
if x.shape[-1] > target_size[-1]:
x = F.adaptive_avg_pool2d(x, target_size)
elif x.shape[-1] < target_size[-1]:
x = F.interpolate(x, size=target_size,
mode='bilinear', align_corners=True)
ans = ans * self.convs[i](x)
return ans
def compute_conv_feature_map_size(self, input_size):
output = np.prod([self.same_channel, *input_size], dtype=np.int64) * self.len
return output
接下来在PlainConvUNet类的__init__函数中实例化这个类、直接加入通道转换的代码,循环部分可以和之前的代码合并:
# 通道转换
Trans_channel = []
sdi = []
# 定义特征融合时的统一通道数,32可以根据读者设备修改,通道数越大,显存占用越多
self.same_channel = 32
for i in range(len(features_per_stage)):
# 通道转换
Trans_channel.append(nn.Conv2d(features_per_stage[i],
self.same_channel, kernel_size=1, padding=0, bias=False))
# SDI类 特征融合
sdi.append(SDI(self.same_channel , len(features_per_stage)))
# 通道转换
self.Trans_chan = nn.Sequential(*Trans_channel)
# SDI类 特征融合
self.sdi = nn.Sequential(*sdi)
在PlainConvUNet类的forward函数中加入通道转换和特征融合,用下面的代码替换:
def forward(self, x):
skips = self.encoder(x)
for i in range(len(skips)):
# 注意力机制
skips[i] = self.chan_atten[i](skips[i])
skips[i] = self.spa_atten[i](skips[i])
# 通道转换
skips[i] = self.Trans_chan[i](skips[i])
# 特征融合
fusion = []
for i in range(len(skips)):
fusion.append(self.sdi[i](skips, skips[i]))
return self.decoder(fusion)
在PlainConvUNet类的compute_conv_feature_map_size函数中,for循环内的代码用下面的代码替换:
for i in range(len(self.features_per_stage)):
# 通道和空间注意力
output += self.chan_atten[i].compute_conv_feature_map_size(input_size)
output += self.spa_atten[i].compute_conv_feature_map_size(input_size)
output += self.sdi[i].compute_conv_feature_map_size(input_size)
input_size = [j // 2 for j in input_size]
这一步完成后,运行nnU-Net V2会报错,因为解码器部分的通道数不对。
找到解码器代码:
self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
nonlin_first=nonlin_first)
UNetDecoder类在dynamic-network-architectures-main \ dynamic_network_architectures \ building_blocks \ unet_decoder.py文件中,在其__init__参数中加入如下字段:
same_channel: int = 32
同时在实例化部分修改通道数为same_channel,按照下面的代码及注释,修改部分代码:
stages = []
transpconvs = []
seg_layers = []
for s in range(1, n_stages_encoder):
#################################### 修改开始
input_features_below = same_channel
input_features_skip = same_channel
#################################### 修改结束
stride_for_transpconv = encoder.strides[-s]
transpconvs.append(transpconv_op(
input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
bias=conv_bias
))
回到PlainConvUNet类的解码器代码,加入same_channel参数,注意,self.same_channel变量一定要在下面代码之前定义:
self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
nonlin_first=nonlin_first, same_channel=self.same_channel)
至此,nnU-Net V2修改完毕。
最后,根据我运行的结果,在nnU-Net V2训练时,前几个批次会会出现dice为0,且loss基本不变的情况,正常,多训练几个批次就行。