Python和PyTorch对比实现批标准化Batch Normalization函数及反向传播

摘要

本文分别使用 Python (基于 NumPy 手写) 和 PyTorch 实现 Batch Normalization (BatchNorm1d) 的前向计算及其反向传播, 并对比两者的输出和梯度是否一致.

相关

原理和详细解释, 请参考文章 :

Batch Normalization函数详解及反向传播中的梯度求导

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

import torch
import numpy as np


class BatchNorm1d:
    """Batch normalization over axis 0, with a hand-written backward pass.

    The layer normalizes each column of a ``(batch, features)`` array with
    the biased batch variance, then applies the affine transform
    ``x_hat * weight + bias``. ``weight`` and ``bias`` are not created
    here; the caller must assign them before the first forward call.

    Attributes:
        eps: Constant added to the variance before taking the square root.
        weight: Per-feature scale, assigned by the caller.
        bias: Per-feature shift, assigned by the caller.
        num: Batch size seen by the last forward call.
        sqrt: ``sqrt(var + eps)`` from the last forward call, kept for
            the backward pass.
        std: Normalized input ``x_hat`` from the last forward call.
        dw: Gradient w.r.t. ``weight`` from the last backward call.
        db: Gradient w.r.t. ``bias`` from the last backward call.
    """

    def __init__(self):
        self.eps = 1e-5
        self.weight = None
        self.bias = None

        # Values cached by __call__ and consumed by backward().
        self.num = None
        self.sqrt = None
        self.std = None

        # Parameter gradients filled in by backward().
        self.dw = None
        self.db = None

    def __call__(self, x):
        """Normalize ``x`` along axis 0 and apply the affine transform.

        Args:
            x: Array whose first axis is the batch axis.

        Returns:
            ``(x - mean) / sqrt(var + eps) * weight + bias``, where ``mean``
            and ``var`` are taken over axis 0 (``np.var`` default, i.e. the
            biased estimator).
        """
        self.num = np.shape(x)[0]
        mean = np.mean(x, axis=0, keepdims=True)
        var = np.var(x, axis=0, keepdims=True)
        self.sqrt = np.sqrt(var + self.eps)
        self.std = (x - mean) / self.sqrt
        return self.std * self.weight + self.bias

    def backward(self, d_loss):
        """Back-propagate ``d_loss`` through the last forward call.

        Stores the parameter gradients in ``self.dw`` / ``self.db`` and
        returns the gradient with respect to the forward input.

        Args:
            d_loss: Gradient of the loss w.r.t. the forward output, with the
                same shape as that output.

        Returns:
            Gradient of the loss w.r.t. the forward input ``x``.
        """
        self.dw = np.sum(self.std * d_loss, axis=0)
        self.db = np.sum(d_loss, axis=0)

        # Closed form of d_loss @ J, where the per-feature Jacobian is
        #   J = weight / (N * sigma) * (N * I - 1 - x_hat x_hat^T).
        # Contracting analytically avoids materializing a (C, N, N) tensor:
        #   dx = weight / (N * sigma)
        #        * (N * dy - sum(dy) - x_hat * sum(dy * x_hat))
        # and the two column sums are exactly db and dw computed above.
        scale = self.weight / self.num / self.sqrt
        dx = scale * (self.num * d_loss - self.db - self.std * self.dw)
        return dx


# Compare the NumPy BatchNorm1d above with torch.nn.BatchNorm1d on identical
# random data: forward output, input gradient, and weight/bias gradients.
np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)
torch.random.manual_seed(123)

# Shared float64 inputs: a batch of 3 samples with 5 features each.
x_numpy = np.array(np.random.random((3, 5)), dtype=np.float64)
weight_numpy = np.array(np.random.random((5,)), dtype=np.float64)
bias_numpy = np.array(np.random.random((5,)), dtype=np.float64)
d_loss_numpy = np.array(np.random.random((3, 5)), dtype=np.float64)

x_tensor = torch.tensor(x_numpy, requires_grad=True)
weight_tensor = torch.tensor(weight_numpy, requires_grad=True)
bias_tensor = torch.tensor(bias_numpy, requires_grad=True)
# The upstream gradient is only fed into backward(); it needs no grad itself.
d_loss_tensor = torch.tensor(d_loss_numpy)

batch_norm_numpy = BatchNorm1d()
batch_norm_numpy.weight = weight_numpy
batch_norm_numpy.bias = bias_numpy

# .double() so the PyTorch module works in float64 like the NumPy arrays.
batch_norm_tensor = torch.nn.BatchNorm1d(5).double()
batch_norm_tensor.weight = torch.nn.Parameter(weight_tensor, requires_grad=True)
batch_norm_tensor.bias = torch.nn.Parameter(bias_tensor, requires_grad=True)

output_numpy = batch_norm_numpy(x_numpy)
output_tensor = batch_norm_tensor(x_tensor)
output_tensor.backward(d_loss_tensor)

dx_numpy = batch_norm_numpy.backward(d_loss_numpy)
dx_tensor = x_tensor.grad

dw_numpy = batch_norm_numpy.dw
dw_tensor = batch_norm_tensor.weight.grad

db_numpy = batch_norm_numpy.db
db_tensor = batch_norm_tensor.bias.grad

# .detach() is the supported way to leave the autograd graph before .numpy().
print("output_numpy \n", output_numpy)
print("output_tensor \n", output_tensor.detach().numpy())

print("dx_numpy \n", dx_numpy)
print("dx_tensor \n", dx_tensor.detach().numpy())

print("dw_numpy \n", dw_numpy)
print("dw_tensor \n", dw_tensor.detach().numpy())

print("db_numpy \n", db_numpy)
print("db_tensor \n", db_tensor.detach().numpy())

"""
代码输出 :
output_numpy 
 [[ 1.65328862  0.60845792  0.51520137  1.06970231  1.47430633]
 [ 0.31989274  1.049903    0.94450061  0.89737849  0.33620922]
 [-0.06997848  0.88993446  0.713664   -0.13401027  0.3568146 ]]
output_tensor 
 [[ 1.65328862  0.60845792  0.51520137  1.06970231  1.47430633]
 [ 0.31989274  1.049903    0.94450061  0.89737849  0.33620922]
 [-0.06997848  0.88993446  0.713664   -0.13401027  0.3568146 ]]
dx_numpy 
 [[ 0.14897849 -0.00280487 -0.19168465 -0.12787269 -0.00214988]
 [-0.65806077 -0.00492716 -0.16475823  0.14902235 -0.12571213]
 [ 0.50908229  0.00773203  0.35644287 -0.02114966  0.127862  ]]
dx_tensor 
 [[ 0.14897849 -0.00280487 -0.19168465 -0.12787269 -0.00214988]
 [-0.65806077 -0.00492716 -0.16475823  0.14902235 -0.12571213]
 [ 0.50908229  0.00773203  0.35644287 -0.02114966  0.127862  ]]
dw_numpy 
 [ 0.10859242  0.09332669  0.21318384 -0.80395148  0.23776771]
dw_tensor 
 [ 0.10859242  0.09332669  0.21318384 -0.80395148  0.23776771]
db_numpy 
 [ 0.72732508  1.22184114  1.55251516  1.73155916  1.55864309]
db_tensor 
 [ 0.72732508  1.22184114  1.55251516  1.73155916  1.55864309]
"""

你可能感兴趣的:(深度学习编程)