Learning from Restormer: Efficient Transformer for High-Resolution Image Restoration (CVPR 2022).
Training a Transformer on small cropped patches may fail to encode global image statistics, which hurts performance on full-resolution images at test time. The authors therefore propose progressive learning: the network is trained on smaller patches in the early epochs and on progressively larger patches in the later epochs. When the patch size grows, the batch size is reduced accordingly.
import random
import numpy as np

# Progressive-training schedule (values taken from the Restormer training config):
batch_size = 8                                        # full batch size (the config's `batch`)
mini_batch_sizes = [8, 5, 4, 2, 1, 1]                 # batch size used in each stage
iters = [92000, 64000, 48000, 36000, 36000, 24000]    # number of iterations per stage
gt_size = 384                                         # max patch size for progressive training
mini_gt_sizes = [128, 160, 192, 256, 320, 384]        # patch size per stage (the config's `gt_sizes`)
scale = 1                                             # upscaling factor; 1 for restoration (kept an int so it can be used in slicing)
# Cumulative iteration boundaries of the stages:
groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])
# groups: [92000, 156000, 204000, 240000, 276000, 300000]

# Index of the stage that the current iteration falls into.
j = ((current_iter > groups) != True).nonzero()[0]
if len(j) == 0:
    bs_j = len(groups) - 1
else:
    bs_j = j[0]
mini_gt_size = mini_gt_sizes[bs_j]
mini_batch_size = mini_batch_sizes[bs_j]

lq = train_data['lq']  # train_data is the dict of (b, c, h, w) tensors returned by the PyTorch DataLoader
gt = train_data['gt']

# Randomly sub-sample the batch when the current stage uses a smaller batch size.
if mini_batch_size < batch_size:
    indices = random.sample(range(0, batch_size), k=mini_batch_size)
    lq = lq[indices]
    gt = gt[indices]

# Randomly crop a mini_gt_size patch out of the full gt_size patch.
if mini_gt_size < gt_size:
    x0 = int((gt_size - mini_gt_size) * random.random())
    y0 = int((gt_size - mini_gt_size) * random.random())
    x1 = x0 + mini_gt_size
    y1 = y0 + mini_gt_size
    lq = lq[:, :, x0:x1, y0:y1]
    gt = gt[:, :, x0 * scale:x1 * scale, y0 * scale:y1 * scale]
numpy nonzero(a): returns a tuple of arrays holding the indices of the non-zero (True) elements of a.
In the example above, when current_iter = 0, (current_iter > groups) != True evaluates to [True, True, True, True, True, True], so nonzero returns (array([0, 1, 2, 3, 4, 5]),) and nonzero(...)[0] is [0, 1, 2, 3, 4, 5].
random.sample(sequence, k): sequence can be a list, tuple, string, or other sequence (passing a set is rejected in recent Python versions); returns a new list of k elements sampled from sequence without replacement.
random.random(): returns a random float in the range [0.0, 1.0).
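A quick interactive illustration of these three helpers (the sampled values shown in the comments are of course only examples):

import random
import numpy as np

print(np.array([0, 1, 0, 2]).nonzero())   # (array([1, 3]),) -- indices of the non-zero entries
print(random.sample(range(8), k=3))        # e.g. [5, 0, 7] -- three distinct indices drawn from 0..7
print(random.random())                     # e.g. 0.6177 -- a float in [0.0, 1.0)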
The authors' ablation study shows that, under a similar parameters/FLOPs budget, a deeper network with a smaller dim in each Transformer block (i.e., a narrower one) is more accurate, while a wider network with a larger dim runs faster.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


# Gated-Dconv feed-forward network (GDFN).
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        # 1x1 conv expands the channels; the factor of 2 produces the two branches of the gate.
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        # 3x3 depth-wise conv encodes local spatial context per channel.
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        # 1x1 conv projects back to the input dimension.
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        # Split into the gating branch and the value branch, then apply the GELU gate.
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x
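A minimal shape check of GDFN (the dim and expansion-factor values below are illustrative, Restormer-like defaults):

ffn = FeedForward(dim=48, ffn_expansion_factor=2.66, bias=False)
x = torch.randn(1, 48, 64, 64)   # (b, c, h, w)
print(ffn(x).shape)              # torch.Size([1, 48, 64, 64]) -- spatial size and channel dim are preserved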
# Multi-Dconv head transposed attention (MDTA): self-attention is computed across the
# channel dimension (a C x C attention map per head) instead of across the H*W spatial positions.
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable scaling factor applied to the dot products of each head.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        # 1x1 conv produces Q, K, V; 3x3 depth-wise conv adds local spatial context.
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1,
                                    groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Reshape to (b, head, c/head, h*w): each channel becomes a token of length h*w.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize Q and K along the spatial dimension before the dot product.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Transposed attention: the attention map has shape (b, head, c/head, c/head).
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
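A minimal shape check of MDTA (again with illustrative values): the output keeps the input shape, while the attention map is only (c/head) x (c/head), so its size does not grow quadratically with the spatial resolution.

attn = Attention(dim=48, num_heads=1, bias=False)
x = torch.randn(1, 48, 64, 64)
print(attn(x).shape)   # torch.Size([1, 48, 64, 64]); internally the attention map is 48 x 48, independent of h*w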
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()
        # 3x3 conv halves the channels, then PixelUnshuffle(2) trades spatial resolution for channels.
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)
Unlike standard strided-conv or pooling downsampling, a convolution is applied first and then pixel-unshuffle: the 3×3 conv reduces the channels to n_feat//2, and PixelUnshuffle(2) halves the height and width while multiplying the channels by 4, so the output has 2×n_feat channels at half the spatial resolution.
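A minimal shape check of the downsampling block (illustrative values):

down = Downsample(n_feat=48)
x = torch.randn(1, 48, 64, 64)
print(down(x).shape)   # torch.Size([1, 96, 32, 32]) -- channels doubled, spatial size halved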