Python和PyTorch对比实现批标准化 Batch Normalization 函数在测试或推理过程中的算法

摘要

本文分别使用 Python (NumPy) 和 PyTorch 实现批标准化 (Batch Normalization) 函数在测试或推理阶段的算法, 并对比两者的输出结果.

相关

原理及详细解释, 请参考文章 :

《Batch Normalization的测试或推理过程及样本参数更新方法》.

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

1. Batch Normalization 类

文件路径 : vanilla_nn/batch_normalization.py

import numpy as np


class BatchNorm1d:
    """Batch normalization over the feature axis of a 2-D ``(N, C)`` input.

    NumPy counterpart of ``torch.nn.BatchNorm1d`` for 2-D input. In training
    mode each feature column is normalized with the current batch's mean and
    (biased) variance. The running estimates are then updated as an
    exponential moving average that uses the *unbiased* batch variance. In
    evaluation mode the running estimates are used instead of batch
    statistics.

    Args:
        train: Truthy for training behaviour, falsy for evaluation. It can be
            switched between calls through the ``train`` attribute.
        momentum: Weight given to the current batch in the running update.
        eps: Constant added to the variance for numerical stability.

    Attributes set outside ``__init__``:
        weight, bias: Affine parameters of shape ``(C,)``. When left as
            ``None`` they default to ones / zeros on the first forward call.
        running_mean, running_var: Running estimates of shape ``(1, C)``,
            starting from zeros / ones.
        dw, db: Gradients of ``weight`` / ``bias`` written by ``backward``.
    """

    def __init__(self, train=True, momentum=0.1, eps=1e-5):
        self.train = train
        self.momentum = momentum
        self.eps = eps

        self.weight = None
        self.bias = None

        # Parameter gradients, filled in by backward().
        self.dw = None
        self.db = None

        # Cached by the most recent forward call for use in backward().
        self.sqrt = None  # sqrt(var + eps) that the input was divided by
        self.std = None  # normalized input (x - mean) / sqrt
        self._batch_stats = None  # True if that call used batch statistics

        self.running_mean = None
        self.running_var = None

    def _init_affine(self, num_features):
        """Default ``weight``/``bias`` to the identity transform when unset."""
        if self.weight is None:
            self.weight = np.ones(num_features)
        if self.bias is None:
            self.bias = np.zeros(num_features)

    def __call__(self, x):
        """Normalize ``x`` of shape ``(N, C)`` and apply ``weight``/``bias``.

        The normalized input and the divisor are cached for ``backward``.

        Raises:
            ValueError: In training mode when ``x`` holds a single sample,
                because the unbiased variance used for the running update
                divides by ``N - 1``.
        """
        x = np.asarray(x)
        num = np.shape(x)[0]
        self._init_affine(np.shape(x)[-1])

        if self.train:
            if num < 2:
                raise ValueError(
                    "Expected more than 1 value per channel when training, "
                    f"got input of shape {np.shape(x)}"
                )
            mean = np.mean(x, axis=0, keepdims=True)
            var = np.var(x, axis=0, keepdims=True)

            if self.running_mean is None:
                self.running_mean = np.zeros_like(mean)
                self.running_var = np.ones_like(var)

            # The batch is normalized with the biased variance, but the
            # running variance tracks the unbiased estimate (N / (N - 1)).
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (
                (1 - self.momentum) * self.running_var
                + self.momentum * var * num / (num - 1)
            )
        else:
            if self.running_mean is None:
                # Evaluating before any training step: use the same initial
                # statistics as the training path (mean 0, variance 1).
                stat_shape = (1,) + np.shape(x)[1:]
                self.running_mean = np.zeros(stat_shape)
                self.running_var = np.ones(stat_shape)
            mean = self.running_mean
            var = self.running_var

        sqrt = np.sqrt(var + self.eps)
        std = (x - mean) / sqrt
        # Cache for backward() in both modes, so its gradient always matches
        # the forward call that produced the output.
        self.sqrt = sqrt
        self.std = std
        self._batch_stats = bool(self.train)
        return std * self.weight + self.bias

    def backward(self, d_loss):
        """Back-propagate ``d_loss`` through the most recent forward call.

        Args:
            d_loss: Gradient of the loss w.r.t. this layer's output, with the
                same shape as the forward input.

        Returns:
            Gradient of the loss w.r.t. the forward input. The parameter
            gradients are stored in ``self.dw`` and ``self.db``.
        """
        d_loss = np.asarray(d_loss)
        self.dw = np.sum(self.std * d_loss, axis=0)
        self.db = np.sum(d_loss, axis=0)

        scale = self.weight / self.sqrt
        if not self._batch_stats:
            # Running statistics are constants w.r.t. x, so the layer is a
            # per-feature affine map and its Jacobian is diagonal.
            return d_loss * scale

        # Batch statistics depend on every sample, so each feature has a dense
        # N x N Jacobian. Multiplying by it gives the closed form
        #   dx = w / (N * sqrt) * (N * g - sum(g) - x_hat * sum(g * x_hat)),
        # which is computed here without materializing the Jacobian.
        num = np.shape(d_loss)[0]
        g_sum = np.sum(d_loss, axis=0, keepdims=True)
        g_dot = np.sum(d_loss * self.std, axis=0, keepdims=True)
        return scale / num * (num * d_loss - g_sum - self.std * g_dot)

2. Batch Normalization 推理测试

import torch
import numpy as np
from vanilla_nn.batch_normalization import BatchNorm1d

np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)
torch.random.manual_seed(123)
nums = 3

# `nums` mini-batches, each holding 3 samples with 5 features.
x_numpy = np.array(np.random.random((nums, 3, 5)), dtype=np.float64)
weight_numpy = np.array(np.random.random((5,)), dtype=np.float64)
bias_numpy = np.array(np.random.random((5,)), dtype=np.float64)

x_tensor = torch.tensor(x_numpy, requires_grad=True)
weight_tensor = torch.tensor(weight_numpy, requires_grad=True)
bias_tensor = torch.tensor(bias_numpy, requires_grad=True)

# Both layers share the same affine parameters and momentum, so their
# outputs (and running statistics) can be compared one-to-one.
batch_norm_numpy = BatchNorm1d(momentum=0.2)
batch_norm_numpy.weight = weight_numpy
batch_norm_numpy.bias = bias_numpy

batch_norm_tensor = torch.nn.BatchNorm1d(5, momentum=0.2).double()
batch_norm_tensor.weight = torch.nn.Parameter(weight_tensor, requires_grad=True)
batch_norm_tensor.bias = torch.nn.Parameter(bias_tensor, requires_grad=True)


def show(title, outputs):
    """Print a section title followed by each batch's output array."""
    print(title)
    for out in outputs:
        print(out)


# Training mode: normalize with batch statistics and update the running
# mean/variance after every batch.
train_numpy = [batch_norm_numpy(x_numpy[i]) for i in range(nums)]
train_tensor = [batch_norm_tensor(x_tensor[i]).detach().numpy() for i in range(nums)]
show("output_numpy", train_numpy)
show("output_tensor", train_tensor)

# Evaluation mode: normalize with the accumulated running statistics only.
batch_norm_numpy.train = False
batch_norm_tensor.eval()
eval_numpy = [batch_norm_numpy(x_numpy[i]) for i in range(nums)]
with torch.no_grad():  # inference only; no autograd graph is needed
    eval_tensor = [batch_norm_tensor(x_tensor[i]).numpy() for i in range(nums)]
show("eval_numpy", eval_numpy)
show("eval_tensor", eval_tensor)

# Check agreement numerically instead of by eye. Unlike a bare assert,
# this check also runs under `python -O`.
for ours, ref in zip(train_numpy + eval_numpy, train_tensor + eval_tensor):
    np.testing.assert_allclose(ours, ref)

"""
output_numpy
[[ 0.46641115  0.18851114 -0.57237571  0.99333476  1.2092339 ]
 [ 0.01389322  1.35696463  1.83911376  0.82492271 -0.10234434]
 [-0.11841836  0.93354663  0.54244234 -0.18305345 -0.07859806]]
[[ 0.36607296  0.34287603 -0.19635263  0.74842757 -0.41712586]
 [ 0.21899343  1.48596725  1.99143145  1.05458233  1.08250983]
 [-0.22318039  0.65017912  0.01410157 -0.16780589  0.36290753]]
[[-0.16707033  1.17972157  0.14300413  0.26912336 -0.20601642]
 [ 0.08560938  1.15577836  1.97239442  1.27262803  0.03626041]
 [ 0.44334695  0.14352248 -0.30621817  0.09345263  1.19804751]]
output_tensor
[[ 0.46641115  0.18851114 -0.57237571  0.99333476  1.2092339 ]
 [ 0.01389322  1.35696463  1.83911376  0.82492271 -0.10234434]
 [-0.11841836  0.93354663  0.54244234 -0.18305345 -0.07859806]]
[[ 0.36607296  0.34287603 -0.19635263  0.74842757 -0.41712586]
 [ 0.21899343  1.48596725  1.99143145  1.05458233  1.08250983]
 [-0.22318039  0.65017912  0.01410157 -0.16780589  0.36290753]]
[[-0.16707033  1.17972157  0.14300413  0.26912336 -0.20601642]
 [ 0.08560938  1.15577836  1.97239442  1.27262803  0.03626041]
 [ 0.44334695  0.14352248 -0.30621817  0.09345263  1.19804751]]
eval_numpy
[[ 0.28282795  0.86633722  0.60257693  0.76006332  0.70923752]
 [ 0.18949885  1.31733272  1.21007392  0.71038713  0.43339924]
 [ 0.16221038  1.15390335  0.88341938  0.41306631  0.43839332]]
[[ 0.29700547  0.79904247  0.53439649  0.74611433  0.55112404]
 [ 0.26163716  1.23206321  1.26263629  0.80220571  0.71174387]
 [ 0.15530735  0.91545366  0.60444962  0.57824888  0.63467022]]
[[ 0.07649121  0.96214393  0.87319294  0.71938833  0.46180685]
 [ 0.15165504  0.95737192  1.48672289  1.03733326  0.52585261]
 [ 0.2580701   0.75562397  0.72253546  0.66372978  0.83297043]]
eval_tensor
[[ 0.28282795  0.86633722  0.60257693  0.76006332  0.70923752]
 [ 0.18949885  1.31733272  1.21007392  0.71038713  0.43339924]
 [ 0.16221038  1.15390335  0.88341938  0.41306631  0.43839332]]
[[ 0.29700547  0.79904247  0.53439649  0.74611433  0.55112404]
 [ 0.26163716  1.23206321  1.26263629  0.80220571  0.71174387]
 [ 0.15530735  0.91545366  0.60444962  0.57824888  0.63467022]]
[[ 0.07649121  0.96214393  0.87319294  0.71938833  0.46180685]
 [ 0.15165504  0.95737192  1.48672289  1.03733326  0.52585261]
 [ 0.2580701   0.75562397  0.72253546  0.66372978  0.83297043]]
"""

你可能感兴趣的:(深度学习编程)