基于pytorch 的 orthogonal_regularization(正交规范化)实现

认为参数需要满足一定条件,希望卷积层参数是正交的。
如果不是正交的,计算与正交之间的距离,然后作为损失进行优化。

本程序给出了orthogonal regularization的pytorch的实现,直接返回模型的损失。

import torch
def orthogonal_regularization(model, device, beta=1e-4):
    r"""
        author: Xu Mingle
        time: 2019年2月19日15:12:43
        input:
            model: which is the model we want to use orthogonal regularization, e.g. Generator or Discriminator
            device: cpu or gpu
            beta: hyperparameter
        output: loss
    """
    
    # beta * (||W^T.W * (1-I)||_F)^2 or 
    # beta * (||W.W.T * (1-I)||_F)^2
    # 若 H < W,可以使用前者, 若 H > W, 可以使用后者,这样可以适当减少内存
    
    
    loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
    
    for name, param in model.named_parameters():
#         print('name is {}'.format(name))
#         print('shape is {}'.format(param.shape))
        if 'weight' in name and param.requires_grad and len(param.shape)==4:
        # 是weight,而不是bias
        # 当然是指定被训练的参数
        # 只对卷积层参数做这样的正则化,而不包括嵌入层(维度是2)等。
            
#             print('shape is {}'.format(param.shape))
#             print('name {}'.format(name))
            
            N, C, H, W = param.shape
#             print('param shape {}'.format(param.shape))
            
            weight = param.view(N * C, H, W)
#             print('flatten shape {}'.format(weight.shape))
            
            weight_squared = torch.bmm(weight, weight.permute(0, 2, 1)) # (N * C) * H * H
#             print('beta_squared shape {}'.format(weight_squared.shape))
            
            ones = torch.ones(N * C, H, H, dtype=torch.float32) # (N * C) * H * H
#             print('ones shape {}'.format(ones.shape))
            
            diag = torch.eye(H, dtype=torch.float32) # (N * C) * H * H
#             print('diag shape {}'.format(diag.shape))
            
            loss_orth += ((weight_squared * (ones - diag).to(device)) ** 2).sum()
            
    return loss_orth * beta

你可能感兴趣的:(十年磨一剑,数学理论,pytorch)