[DeepLearning-7] A Hybrid CNN and Transformer Neural Network Architecture

Constructor: __init__

def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
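  • The constructor defines the network's main structure and parameters.
  • channel: number of channels in the input feature map.
  • dim: feature dimension used by the Transformer part.
  • depth: number of Transformer layers.
  • kernel_size: kernel size of the n×n convolution layers.
  • patch_size: (height, width) of the patches the feature map is split into.
  • mlp_dim: hidden dimension of the Transformer's feed-forward network.
  • dropout: dropout rate, used for regularization.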

Layer definitions

self.mv01 = IRBlock(channel, channel)
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
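  • IRBlock, conv_nxn_bn, and conv_1x1_bn handle feature extraction and changes of channel dimension.
  • UserDefined is the Transformer-based module mentioned earlier; it processes the feature map as sequence data.
  • Together, these layers combine the CNN's spatial feature extraction with the Transformer's sequence modeling.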
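The channel bookkeeping implied by these layers can be traced in plain Python. This is a minimal sketch rather than part of the original code: trace_channels and concat_matches_conv4 are hypothetical helpers, and the sketch assumes that UserDefined keeps the feature dimension at dim (which conv3's input width suggests) and that the tensor concatenated before conv4 is a copy of conv2's output, as in the forward method.

```python
def trace_channels(channel, dim):
    """Return (stage, channel count) pairs, following the layer definitions."""
    stages = [("input", channel)]
    stages.append(("conv1", channel))        # conv_nxn_bn(channel, channel)
    stages.append(("conv2", dim))            # conv_1x1_bn(channel, dim)
    skip = dim                               # z: copy of conv2's output
    stages.append(("transformer", dim))      # assumed to keep `dim` features
    stages.append(("conv3", channel))        # conv_1x1_bn(dim, channel)
    stages.append(("concat", channel + skip))  # torch.cat along the channel axis
    return stages


def concat_matches_conv4(channel, dim):
    """conv4 is built as conv_nxn_bn(2 * channel, channel, ...)."""
    concat_width = dict(trace_channels(channel, dim))["concat"]
    return concat_width == 2 * channel


print(trace_channels(64, 64))
print(concat_matches_conv4(64, 64))   # True
print(concat_matches_conv4(64, 96))   # False
```

Under these assumptions the concatenation is channel + dim wide, so it only matches conv4's declared input of 2 * channel when dim equals channel.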

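import torch.nn as nn

def conv_1x1_bn(inp, oup):
    # 1x1 convolution -> batch norm -> SiLU; changes only the channel count.
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # n x n convolution -> batch norm -> SiLU.
    # padding = kernel_size // 2 preserves height and width at stride 1
    # for odd kernel sizes (it equals 1 for the default kernel_size=3).
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )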
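The residual addition in this block only works if the n×n convolutions leave height and width unchanged. The sketch below applies the output-size formula documented for nn.Conv2d with dilation 1, floor((size + 2 * padding - kernel_size) / stride) + 1, to a few settings. conv_out_size is a hypothetical helper written for illustration, and the feature-map size 32 is an arbitrary example value.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial axis for a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# 1x1 kernel, stride 1, padding 0 (the settings used in conv_1x1_bn).
print(conv_out_size(32, 1, 1, 0))        # 32
# 3x3 kernel, stride 1, padding 1: the size is preserved.
print(conv_out_size(32, 3, 1, 1))        # 32
# 5x5 kernel with padding 1: the map shrinks by 2 along each axis.
print(conv_out_size(32, 5, 1, 1))        # 30
# 5x5 kernel with padding kernel_size // 2: the size is preserved again.
print(conv_out_size(32, 5, 1, 5 // 2))   # 32
```

In other words, a padding of 1 pairs with kernel_size=3; for other odd kernel sizes at stride 1 the padding has to be kernel_size // 2 to keep the spatial size.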

 

Forward pass: forward

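def forward(self, x):
    y = x.clone()               # keep the input for the residual connection
    x = self.conv1(x)           # n x n conv: local spatial features
    x = self.conv2(x)           # 1x1 conv: channel -> dim
    z = x.clone()               # keep these features for the later concatenation
    _, _, h, w = x.shape
    # Split the map into ph x pw patches and group pixels by their position inside each patch.
    x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
    x = self.transformer(x)
    # Undo the regrouping to recover a (b, d, H, W) feature map.
    x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
    x = self.conv3(x)           # 1x1 conv: dim -> channel
    x = torch.cat((x, z), 1)    # concatenate along the channel axis
    x = self.conv4(x)           # n x n conv: fuse back to `channel` channels
    x = x + y                   # residual connection
    x = self.mv01(x)            # final IRBlock
    return x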
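  • The forward method defines how data flows through the network.
  • The input x first passes through conv1 and conv2 for feature extraction and a change of channel dimension (channel to dim).
  • The result is regrouped with rearrange into shape (b, ph*pw, (H/ph)*(W/pw), d) so it can be fed to the Transformer; the feature-map height H and width W must be divisible by ph and pw.
  • The Transformer output is rearranged back to the original (b, d, H, W) layout.
  • conv3 maps the features back to channel channels; they are concatenated with z and fused by conv4, the saved input y is added as a residual connection, and the result passes through the IRBlock (self.mv01).
  • Note that z is taken after conv2 and so has dim channels: the concatenation has channel + dim channels while conv4 expects 2 * channel, so the code as written requires dim == channel.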
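To see what the first rearrange pattern does, the following sketch reproduces its index mapping for a single channel in plain Python, without einops. unfold_patches and fold_patches are hypothetical helper names. The mapping follows the pattern 'b d (h ph) (w pw) -> b (ph pw) (h w) d', where the first name in each parenthesised pair is the slower-varying index.

```python
def unfold_patches(grid, ph, pw):
    """Regroup an H x W grid into ph*pw groups of (H/ph)*(W/pw) values each."""
    H, W = len(grid), len(grid[0])
    if H % ph or W % pw:
        raise ValueError("height and width must be divisible by the patch size")
    h, w = H // ph, W // pw
    groups = []
    for pi in range(ph):              # row position inside a patch
        for pj in range(pw):          # column position inside a patch
            groups.append([grid[hi * ph + pi][wi * pw + pj]
                           for hi in range(h)      # patch row
                           for wi in range(w)])    # patch column
    return groups


def fold_patches(groups, h, w, ph, pw):
    """Inverse of unfold_patches: rebuild the (h*ph) x (w*pw) grid."""
    grid = [[None] * (w * pw) for _ in range(h * ph)]
    for g, group in enumerate(groups):
        pi, pj = divmod(g, pw)
        for n, value in enumerate(group):
            hi, wi = divmod(n, w)
            grid[hi * ph + pi][wi * pw + pj] = value
    return grid


grid = [[4 * r + c for c in range(4)] for r in range(4)]   # values 0..15
groups = unfold_patches(grid, 2, 2)
print(groups)  # [[0, 2, 8, 10], [1, 3, 9, 11], [4, 6, 12, 14], [5, 7, 13, 15]]
print(fold_patches(groups, 2, 2, 2, 2) == grid)  # True
```

Each group collects the pixels that sit at the same position inside every patch, so the second axis runs across patches. How UserDefined attends over these axes depends on its implementation, which is not shown in this section.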

Full code:

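import torch
import torch.nn as nn
from einops import rearrange

# IRBlock and UserDefined are assumed to be defined elsewhere; they are not shown in this section.
class MobileViTBv3(nn.Module):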
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = IRBlock(channel, channel) 
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    def forward(self, x):
        y = x.clone()
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        x = self.conv3(x)
        x = torch.cat((x, z), 1)
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x
