Progressive learning

Notes on Restormer: Efficient Transformer for High-Resolution Image Restoration (IEEE/CVF CVPR 2022).

Progressive learning

Training a Transformer on small cropped patches may fail to encode global image statistics, which hurts performance on full-resolution images at test time. The authors therefore propose progressive learning: the network is trained on smaller image patches in the early training epochs and on progressively larger patches in the later ones. Because larger patches make each step more expensive, the batch size is reduced as the patch size grows.

# batch_size: 8
# mini_batch_sizes: [8,5,4,2,1,1]
# iters: [92000,64000,48000,36000,36000,24000]
# gt_size: 384   # max patch size for progressive training
# gt_sizes: [128,160,192,256,320,384]   # read as mini_gt_sizes in the code below
# scale = 1
# groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])
# groups: [92000, 156000, 204000, 240000, 276000, 300000]

  
import random
import numpy as np

# pick the progressive-training stage: the first group whose cumulative
# iteration boundary has not yet been passed
j = ((current_iter > groups) != True).nonzero()[0]
if len(j) == 0:
    bs_j = len(groups) - 1  # past the last boundary: stay in the final stage
else:
    bs_j = j[0]

mini_gt_size = mini_gt_sizes[bs_j]        # patch size for this stage (gt_sizes[bs_j])
mini_batch_size = mini_batch_sizes[bs_j]  # batch size for this stage


lq = train_data['lq']  # train_data is the dict returned by the PyTorch DataLoader; lq/gt are (b, c, h, w) tensors
gt = train_data['gt']

if mini_batch_size < batch_size:
    # randomly keep only mini_batch_size samples out of the full batch
    indices = random.sample(range(0, batch_size), k=mini_batch_size)
    lq = lq[indices]
    gt = gt[indices]

if mini_gt_size < gt_size:
    # crop a random mini_gt_size x mini_gt_size window out of the full patch
    x0 = int((gt_size - mini_gt_size) * random.random())
    y0 = int((gt_size - mini_gt_size) * random.random())
    x1 = x0 + mini_gt_size
    y1 = y0 + mini_gt_size
    lq = lq[:, :, x0:x1, y0:y1]
    # the gt crop is scaled in case gt is an upsampled target; scale = 1 here
    gt = gt[:, :, x0 * scale:x1 * scale, y0 * scale:y1 * scale]

numpy's nonzero(a): returns a tuple of index arrays, one per dimension, giving the positions of the nonzero (here, True) elements of a.

In the example above, when current_iter = 0, (current_iter > groups) != True evaluates to [True, True, True, True, True, True], so nonzero returns (array([0, 1, 2, 3, 4, 5]),), and nonzero(...)[0] is array([0, 1, 2, 3, 4, 5]).
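A quick sanity check of the stage lookup, as a minimal standalone sketch reusing the config values above:

import numpy as np

iters = [92000, 64000, 48000, 36000, 36000, 24000]
groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])

for current_iter in (0, 92001, 299999, 400000):
    j = ((current_iter > groups) != True).nonzero()[0]
    bs_j = len(groups) - 1 if len(j) == 0 else j[0]
    print(current_iter, '-> stage', bs_j)
# 0 -> stage 0, 92001 -> stage 1, 299999 -> stage 5, 400000 -> stage 5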

random.sample(sequence, k): sequence can be a list, tuple, string, or set; it returns a new list of k elements chosen from sequence without replacement.

random.random(): returns a random floating-point number in the half-open interval [0.0, 1.0).
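For the crop above this means that, with gt_size = 384 and the smallest stage's mini_gt_size = 128, the offset formula produces an integer in [0, 255]:

import random

gt_size, mini_gt_size = 384, 128
x0 = int((gt_size - mini_gt_size) * random.random())
print(0 <= x0 <= gt_size - mini_gt_size - 1)  # True: random.random() never returns 1.0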

Deeper or wider network

An ablation in the paper shows that, at a similar parameter/FLOPs budget, a deeper network with a small per-block channel dim (narrow) is more accurate, while a shallower network with a large dim (wide) runs faster.
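To see why such pairs can share a budget: the 1x1 projections dominate each block's parameter count and scale with dim squared, so halving dim frees room for roughly 4x more blocks. A hypothetical back-of-envelope check (the dims below are illustrative, not the paper's exact configurations):

import torch.nn as nn

def block_params(dim, expansion=2.66):
    # 1x1 projections of one MDTA + GDFN block; the cheap depth-wise convs are ignored
    hidden = int(dim * expansion)
    block = nn.ModuleList([
        nn.Conv2d(dim, dim * 3, 1),      # qkv projection
        nn.Conv2d(dim, dim, 1),          # attention output projection
        nn.Conv2d(dim, hidden * 2, 1),   # GDFN project_in
        nn.Conv2d(hidden, dim, 1),       # GDFN project_out
    ])
    return sum(p.numel() for p in block.parameters())

print(4 * block_params(48))  # deep & narrow: 4 blocks at dim=48 -> ~112k params
print(1 * block_params(96))  # shallow & wide: 1 block at dim=96 -> ~111k params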

Network architecture in Restormer

1. Gated-Dconv Feed-Forward Network (GDFN)

import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        # 1x1 conv expands the channels for both gating branches at once
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)

        # 3x3 depth-wise conv: per-channel spatial mixing
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)  # split into the two branches
        x = F.gelu(x1) * x2                      # gating: GELU(x1) modulates x2
        x = self.project_out(x)
        return x
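A quick shape check (dim = 48 and ffn_expansion_factor = 2.66 match Restormer's default config, but any values work):

import torch

ffn = FeedForward(dim=48, ffn_expansion_factor=2.66, bias=False)
x = torch.randn(1, 48, 64, 64)
print(ffn(x).shape)  # torch.Size([1, 48, 64, 64]): the channel count is preserved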

2. Multi-DConv Head Transposed Self-Attention (MDTA)

import torch
import torch.nn as nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # learnable scaling of the attention logits, one scalar per head
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # flatten the spatial dims: attention is computed across channels, not pixels
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # (c x hw) @ (hw x c) -> c x c attention map per head
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
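The key point of MDTA is that the attention map has shape c x c per head instead of hw x hw, so the cost grows linearly rather than quadratically with the number of pixels. A quick shape check:

import torch

attn = Attention(dim=48, num_heads=1, bias=False)
x = torch.randn(1, 48, 64, 64)
print(attn(x).shape)  # torch.Size([1, 48, 64, 64])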

3. Downsample 

import torch.nn as nn


class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        # 3x3 conv halves the channels, then PixelUnshuffle(2) halves each
        # spatial dim while multiplying channels by 4: net effect is 2x channels
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

Unlike standard downsampling, this first applies a convolution and then pixel-unshuffle, which rearranges spatial positions into channels instead of discarding them.
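A shape check of the net effect (channels double, each spatial dimension halves):

import torch

down = Downsample(n_feat=48)
x = torch.randn(1, 48, 64, 64)
print(down(x).shape)  # torch.Size([1, 96, 32, 32])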
