U-Net 是一种用于生物医学图像分割的卷积神经网络架构。它通过引入一种新颖的网络结构和训练策略解决了传统方法在数据量不足时面临的挑战。U-Net 的主要思想是利用数据增强技术来高效利用有限的标注样本,并通过独特的网络设计来提高分割精度。
U-Net 的主要贡献包括:
1、数据增强策略:使用随机弹性变形和其他形式的数据增强来增加训练数据的多样性,从而在有限的数据集上训练出更强大的模型。
2、U 形网络结构:包含一个收缩路径(downsampling path)用于捕获上下文信息,以及一个对称的扩展路径(upsampling path)用于精确定位。
3、快速推理:U-Net 能够在现代 GPU 上快速执行,对于 512x512 的图像,分割只需要不到一秒钟的时间。
4、高性能:在 ISBI 挑战赛中,U-Net 在神经元结构分割和细胞追踪任务中表现出了卓越的性能。
具体的解析可以看这篇文章:Unet详解
①Unet原文里面使用3*3卷积的时候没有用一下padding,所以卷积过程当中输入的图像在还原的时候会有像素损失,所以在实现的时候我使用了padding=1,方式是“reflect”的方法。
②原论文发表的时候Batch Norm还不流行,因此实现的时候我加了上去
③网络的上采样层我的实现方法是:双线性插值+卷积层运算,也可以使用转置卷积
④下采样原论文使用的是最大池化,实现的时候我使用的是stride=2的卷积层,我认为这样拟合能力更好。
①model.py
import torch
from torch import nn
from torch.nn import functional as F
class conv_block(nn.Module):
def __init__(self,in_channel,out_channel):
super(conv_block,self).__init__()
self.layer=nn.Sequential(
nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=3,stride=1,padding=1,padding_mode="reflect"),
nn.BatchNorm2d(out_channel),
nn.Dropout(0.5),
nn.LeakyReLU()
)
def forward(self,x):
x=self.layer(x)
return x
class downsample(nn.Module):
def __init__(self,input_channel):
super(downsample,self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(input_channel,input_channel, kernel_size=3, stride=2, padding=1, padding_mode="reflect"),
nn.BatchNorm2d(input_channel),
nn.LeakyReLU()
)
def forward(self,x):
x=self.layer(x)
return x
class upsample(nn.Module):
def __init__(self,input_channel,output_channel):
super(upsample,self).__init__()
self.layer=nn.Sequential(
nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=1, padding=1, padding_mode="reflect"),
nn.BatchNorm2d(output_channel),
nn.LeakyReLU()
)
def forward(self,x):
x=F.interpolate(x,scale_factor=2,mode='bilinear')
x=self.layer(x)
return x
class Unet(nn.Module):
def __init__(self,num_class):
super(Unet,self).__init__()
self.con1=conv_block(3,64)
self.con2=conv_block(64,64)
self.down1=downsample(64)
self.con3=conv_block(64,128)
self.con4=conv_block(128,128)
self.down2=downsample(128)
self.con5=conv_block(128,256)
self.con6=conv_block(256,256)
self.down3=downsample(256)
self.con7=conv_block(256,512)
self.con8=conv_block(512,512)
self.down4=downsample(512)
self.con9 = conv_block(512, 1024)
self.con10 = conv_block(1024, 1024)
#--------------------------上面是encoder部分------------------------------
self.up1=upsample(1024,512)
self.con11 = conv_block(1024, 512)
self.con12= conv_block(512