从零搭建Pytorch模型教程(三)搭建Transformer网络

前言

本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍。

在讲如何搭建之前,先回顾一下Transformer在计算机视觉中的结构是怎样的。这里以最典型的ViT为例。

从零搭建Pytorch模型教程(三)搭建Transformer网络_第1张图片

如图所示,对于一张图像,先将其分割成NxN个patches,把patches进行Flatten,再通过一个全连接层映射成tokens,对每一个tokens加入位置编码(position embedding),会随机初始化一个tokens,concate到通过图像生成的tokens后,再经过transformer的Encoder模块,经过多层Encoder后,取出最后的tokens(即随机初始化的tokens),再通过全连接层作为分类网络进行分类。

下面我们就根据这个流程来一步一步介绍如何搭建一个Transformer模型。、

分块


目前有两种方式实现分块,一种是直接分割,一种是通过卷积核和步长都为patch大小的卷积来分割。

直接分割

直接分割即把图像直接分成多块。在代码实现上需要使用einops这个库,完成的操作是将(B,C,H,W)的shape调整为(B,(H/P *W/P),P*P*C)。

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

self.to_patch_embedding = nn.Sequential(
           Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
           nn.Linear(patch_dim, dim),
      )

这里简单介绍一下Rearrange。

Rearrange用于对张量的维度进行重新变换排序,可用于替换pytorch中的reshape,view,transpose和permute等操作。举几个例子

#假设images的shape为[32,200,400,3]
#实现view和reshape的功能
Rearrange(images,'b h w c -> (b h) w c')#shape变为(32*200, 400, 3)
#实现permute的功能
Rearrange(images, 'b h w c -> b c h w')#shape变为(32, 3, 200, 400)
#实现这几个都很难实现的功能
Rearrange(images, 'b h w c -> (b c w) h')#shape变为(32*3*400, 200)

从这几个例子看可以看出,Rearrange非常简单好用,这里的b, c, h, w都可以理解为表示符号,用来表示操作变化。通过这几个例子似乎也能理解下面这行代码是如何将图像分割的。

Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)

这里需要解释的是,一个括号内的两个变量相乘表示的是该维度的长度,因此不要把"h"和"w"理解成图像的宽和高。这里实际上h = H/p1, w = W/p2,代表的是高度上有几块,宽度上有几块。h和w都不需要赋值,代码会自动根据这个表达式计算,b和c也会自动对应到输入数据的B和C。

后面的"b (h w) (p1 p2 c)"表示了图像分块后的shape: (B,(H/P *W/P),P*P*C)

这种方式在分块后还需要通过一层全连接层将分块的向量映射为tokens。

在ViT中使用的就是这种直接分块方式。

卷积分割

卷积分割比较容易理解,使用卷积核和步长都为patch大小的卷积对图像卷积一次就可以了。

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C

在swin transformer中即使用的是这种卷积分块方式。在swin transformer中卷积后没有再加全连接层。

Position Embedding


Position Embedding可以分为absolute position embedding和relative position embedding。

在学习最初的transformer时,可能会注意到用的是正余弦编码的方式,但这只适用于语音、文字等1维数据,图像是高度结构化的数据,用正余弦不合适

在ViT和swin transformer中都是直接随机初始化一组与tokens同shape的可学习参数,与tokens相加,即完成了absolute position embedding。

在ViT中实现方式:

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
#之所以是n+1,是因为ViT中选择随机初始化一个class token,与分块得到的tokens拼接。所以patches的数量为num_patches+1。

在swin transformer中的实现方式:

from timm.models.layers import trunc_normal_
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)

在TimeSformer中的实现方式:

self.pos_emb = torch.nn.Embedding(num_positions + 1, dim)

以上就是简单的使用方法,这种方法属于absolute position embedding。

Encoder


Encoder由Multi-head Self-attention和FeedForward组成。

你可能感兴趣的:(pytorch,transformer,网络,人工智能,ai,语言模型,机器学习)