Transformer实战-系列教程9:SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)

Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

3、PatchEmbed类

3.1 构造函数

class PatchEmbed(nn.Module):
	def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    
  • img_size 定义了输入图像的尺寸,默认为 224x224 像素。
  • patch_size 定义了每个patch的大小,默认为 4x4 像素,这意味着每个patch包含 4x4=16 个像素。
  • in_chans 指定了输入图像的通道数,默认为 3,对应于常见的 RGB 图像。
  • embed_dim 定义了线性投影的输出通道数,即每个patch的嵌入维度,默认为 96。
  • norm_layer 是一个可选的标准化层,用于在嵌入后应用标准化。
  • self.proj 是一个卷积层,用于将输入图像的每个patch转换成嵌入向量。其使用了与patch大小相同的卷积核和步长,确保图像被分割成不重叠的patch。

3.2 前向传播

def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C 
        if self.norm is not None:
            x = self.norm(x)
        return x

patch_embbeding,主要通过Swin_Transformer.py的PatchEmbed类实现
对PatchEmbed类的前向传播进行debug:

  • 原始输入x: torch.Size([4, 3, 224, 224]),输入是一张彩色图像
  • 将x的shape值转化为,B, C, H, W
  • 检查输入图像的尺寸是否与模型期望的尺寸相匹配
  • torch.Size([4, 3136, 96]),使用 self.proj 卷积层将图像分割成patch并进行线性投影。操作后,张量被展平并转置,以匹配 (B, N, C) 的形状,其中 N 是patch数量,C 是嵌入维度。3136=56*566,56=224/4卷积得到
  • torch.Size([4, 3136, 96]),如果定义了 norm_layer,则在嵌入向量上应用标准化

4、BasicLayer类

4.1 构造函数

class BasicLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

主要参数解析:

  • dim (int): 输入通道数
  • input_resolution (tuple[int]): Input resolution.
  • depth (int): 在BasicLayer中堆叠的Swin Transformer Block的数量
  • num_heads (int): 多头注意力的头数
  • window_size (int): 窗口大小,决定了注意力机制的局部范围
  • mlp_ratio (float): MLP隐层大小与输入通道数的比率
  • qkv_bias (bool, optional): q、k、v是否要设置的可学习偏执
  • qk_scale (float | None, optional): q、k、v的缩放因子
  • drop (float, optional): Dropout 的比例,默认为0
  • attn_drop (float, optional): Attention dropout 的比例,默认为0
  • drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  • norm_layer (nn.Module, optional): 从torch中导入的层归一化
  • downsample (nn.Module | None, optional): 从torch中导入的下采样
  • use_checkpoint (bool): 是否采用checkpoint保存训练过程来节省内存

4.2 前向传播

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

BasicLayer类的实例,在SwinTransformer类中被一个for循环多次调用,因此每次调用的情况会产生变化

原始输入x: torch.Size([4, 3136, 96])
blb(x): torch.Size([4, 3136, 96])
blb(x): torch.Size([4, 3136, 96])
torch.Size([4, 784, 192])

原始输入x: torch.Size([4, 784, 192])
blb(x): torch.Size([4, 784, 192])
blb(x): torch.Size([4, 784, 192])
torch.Size([4, 196, 384])

原始输入x: torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
blb(x): torch.Size([4, 196, 384])
torch.Size([4, 49, 768])

原始输入x: torch.Size([4, 49, 768])
blb(x): torch.Size([4, 49, 768])
blb(x): torch.Size([4, 49, 768])
torch.Size([4, 49, 768])

每4次一个循环,最开始的3136是序列长度,96是每个向量的维度,序列长度不断变短,而向量维度在增加,这个变化主要是在后面的下采样中产生的变化,在循环多次调用SwinTransformerBlock中维度没有产生变化

BasicLayer 提供了 Swin Transformer 模型中一个阶段的完整实现,包括自注意力块的堆叠、可选的下采样处理,以及对应的配置选项,如自注意力的头数、窗口大小等。通过这种模块化的设计,Swin Transformer 能够灵活地适应不同的任务和数据集,同时保持较高的计算效率和表示能力

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

你可能感兴趣的:(Transformer实战,transformer,pytorch,深度学习,计算机视觉,SwinTransformer,人工智能)