Python vs. PyTorch: Implementing the convolution Function and Its Backward Pass

Abstract

This article implements the convolution function and its backward pass in pure Python (with NumPy) and compares the results against PyTorch.

Related

For the underlying theory and a detailed derivation, see the article:

The convolution function explained in detail, with gradient derivation in backpropagation

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
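
For reference, a minimal summary of the formulas the code below implements (notation mine, not necessarily that of the referenced article): a valid, unpadded convolution with stride $t$, an $R \times S$ kernel $w$, and bias $b$ computes

$$\mathrm{out}_{ij} = b + \sum_{u=0}^{R-1} \sum_{v=0}^{S-1} x_{it+u,\, jt+v} \, w_{uv}$$

and, given an upstream gradient $\delta_{ij} = \partial L / \partial\, \mathrm{out}_{ij}$, the backward pass is

$$\frac{\partial L}{\partial w_{uv}} = \sum_{i,j} \delta_{ij} \, x_{it+u,\, jt+v}, \qquad
\frac{\partial L}{\partial b} = \sum_{i,j} \delta_{ij}, \qquad
\frac{\partial L}{\partial x_{pq}} = \sum_{it+u=p,\; jt+v=q} \delta_{ij} \, w_{uv}$$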

Main Text

import torch
import numpy as np


class Conv2d:
    """Single-channel 2D convolution (cross-correlation, as in PyTorch),
    without padding, with hand-written forward and backward passes."""

    def __init__(self, stride=1):
        self.weight = None
        self.bias = None

        self.stride = stride

        self.x = None
        self.dw = None
        self.db = None

        self.input_height = None
        self.input_width = None
        self.weight_height = None
        self.weight_width = None
        self.output_height = None
        self.output_width = None

    def __call__(self, x):
        self.x = x
        self.input_height = np.shape(x)[0]
        self.input_width = np.shape(x)[1]
        self.weight_height = np.shape(self.weight)[0]
        self.weight_width = np.shape(self.weight)[1]

        # Output size of a valid (unpadded) convolution:
        # out = (in - kernel) // stride + 1
        self.output_height = int((self.input_height - self.weight_height) / self.stride) + 1
        self.output_width = int((self.input_width - self.weight_width) / self.stride) + 1

        # Slide the kernel across the input and accumulate the weighted sums.
        out = np.zeros((self.output_height, self.output_width))
        for i in range(self.output_height):
            for j in range(self.output_width):
                for r in range(self.weight_height):
                    for s in range(self.weight_width):
                        out[i, j] += x[i * self.stride + r, j * self.stride + s] * self.weight[r, s]
        out = out + self.bias
        return out

    def backward(self, d_loss):
        dx = np.zeros_like(self.x)
        self.dw = np.zeros_like(self.weight)
        self.db = np.zeros_like(self.bias)

        for i in range(self.output_height):
            for j in range(self.output_width):
                start_i = i * self.stride
                start_j = j * self.stride
                end_i = start_i + self.weight_height
                end_j = start_j + self.weight_width
                # Each output element scatters its gradient back onto the
                # input window it came from, scaled by the weights.
                dx[start_i: end_i, start_j:end_j] += d_loss[i, j] * self.weight

                # The weight gradient accumulates the matching input values
                # across all output positions.
                for u in range(self.weight_height):
                    for v in range(self.weight_width):
                        self.dw[u, v] += d_loss[i, j] * self.x[start_i + u, start_j + v]

        # The bias contributes to every output element, so its gradient is
        # the sum of the whole upstream gradient.
        self.db = np.sum(d_loss)

        return dx
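

# For reference: an equivalent, partly vectorized forward pass -- a sketch that
# replaces the two innermost loops with a windowed elementwise product. It is
# defined here only for illustration and is not used by the comparison below;
# the name conv2d_forward_windowed is not part of the original code.
def conv2d_forward_windowed(x, weight, bias, stride=1):
    kh, kw = weight.shape
    oh = (x.shape[0] - kh) // stride + 1
    ow = (x.shape[1] - kw) // stride + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            window = x[i * stride: i * stride + kh, j * stride: j * stride + kw]
            out[i, j] = np.sum(window * weight)
    return out + bias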


np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)
torch.random.manual_seed(123)

x_numpy = np.random.random((1, 3, 5, 5))
x_tensor = torch.tensor(x_numpy, requires_grad=True)

conv2d_tensor = torch.nn.Conv2d(3, 1, (3, 3), stride=2).double()
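
# PyTorch's Conv2d(3, 1, (3, 3)) convolves each of the 3 input channels with
# its own slice of the kernel, sums the per-channel results, and adds a single
# bias. The same computation is reproduced below with three single-channel
# Conv2d instances, one per input channel.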

conv2d_numpy_channel_0 = Conv2d(stride=2)
conv2d_numpy_channel_0.weight = conv2d_tensor.weight.data.numpy()[0, 0]
conv2d_numpy_channel_0.bias = conv2d_tensor.bias.data.numpy()[0]

conv2d_numpy_channel_1 = Conv2d(stride=2)
conv2d_numpy_channel_1.weight = conv2d_tensor.weight.data.numpy()[0, 1]
conv2d_numpy_channel_1.bias = conv2d_tensor.bias.data.numpy()[0]

conv2d_numpy_channel_2 = Conv2d(stride=2)
conv2d_numpy_channel_2.weight = conv2d_tensor.weight.data.numpy()[0, 2]
conv2d_numpy_channel_2.bias = conv2d_tensor.bias.data.numpy()[0]

out_numpy_0 = conv2d_numpy_channel_0(x_numpy[0, 0])
out_numpy_1 = conv2d_numpy_channel_1(x_numpy[0, 1])
out_numpy_2 = conv2d_numpy_channel_2(x_numpy[0, 2])
# Each single-channel instance added the same bias once, so subtract the two
# duplicate bias contributions after summing the channel outputs.
out_numpy = out_numpy_0 + out_numpy_1 + out_numpy_2 - conv2d_numpy_channel_0.bias * 2
out_tensor = conv2d_tensor(x_tensor)

# Drive backpropagation with a random upstream gradient dL/d_out.
d_loss_numpy = np.random.random(out_tensor.shape)
d_loss_tensor = torch.tensor(d_loss_numpy)

# There is a single output channel, so every per-channel instance receives
# the same upstream gradient d_loss_numpy[0][0].
dx_numpy_0 = conv2d_numpy_channel_0.backward(d_loss_numpy[0][0])
dx_numpy_1 = conv2d_numpy_channel_1.backward(d_loss_numpy[0][0])
dx_numpy_2 = conv2d_numpy_channel_2.backward(d_loss_numpy[0][0])

out_tensor.backward(d_loss_tensor)
dx_tensor = x_tensor.grad

dw_numpy_0 = conv2d_numpy_channel_0.dw
dw_numpy_1 = conv2d_numpy_channel_1.dw
dw_numpy_2 = conv2d_numpy_channel_2.dw
dw_tensor = conv2d_tensor.weight.grad

db_numpy = conv2d_numpy_channel_0.db
db_tensor = conv2d_tensor.bias.grad

print("out_numpy \n", out_numpy)
print("out_tensor \n", out_tensor.data.numpy())

print("dx_numpy_0 \n", dx_numpy_0)
print("dx_numpy_1 \n", dx_numpy_1)
print("dx_numpy_2 \n", dx_numpy_2)
print("dx_tensor \n", dx_tensor.data.numpy())

print("dw_numpy_0 \n", dw_numpy_0)
print("dw_numpy_1 \n", dw_numpy_1)
print("dw_numpy_2 \n", dw_numpy_2)
print("dw_tensor \n", dw_tensor.data.numpy())

print("db_numpy \n", db_numpy)
print("db_tensor \n", db_tensor.data.numpy()[0])

"""
Code output:
out_numpy 
 [[-0.01129476 -0.12804462]
 [-0.04662733 -0.12602237]]
out_tensor 
 [[[[-0.01129476 -0.12804462]
   [-0.04662733 -0.12602237]]]]
dx_numpy_0 
 [[-0.04664896  0.0037894  -0.10051156  0.00354941 -0.05321863]
 [ 0.04314122 -0.09747379  0.1242681  -0.09130056  0.07854812]
 [-0.09562402 -0.08993824 -0.17737642 -0.08421562 -0.08233961]
 [ 0.01153658 -0.0260659   0.03353431 -0.02510021  0.02159433]
 [-0.02223535 -0.02432176 -0.04074214 -0.02342069 -0.0186144 ]]
dx_numpy_1 
 [[ 0.05181031 -0.04226934  0.09133848 -0.03959233  0.04009821]
 [-0.09709334 -0.06940826 -0.13294859 -0.06501248 -0.03934415]
 [-0.00862673 -0.09857383  0.07863871 -0.09262803  0.08118677]
 [-0.02596416 -0.01856077 -0.03623482 -0.01787313 -0.01081643]
 [-0.00601189 -0.02333736  0.01424206 -0.02247276  0.01928911]]
dx_numpy_2 
 [[-0.02697872  0.03672051  0.05562605  0.03439492  0.07577281]
 [ 0.02131312  0.03128234  0.13038743  0.02930116  0.10343071]
 [-0.05880928  0.04605546 -0.08453932  0.04339676 -0.02684292]
 [ 0.00569944  0.00836535  0.03501728  0.00805543  0.02843501]
 [-0.01379719  0.00968999 -0.02689682  0.009331   -0.01310653]]
dx_tensor 
 [[[[-0.04664896  0.0037894  -0.10051156  0.00354941 -0.05321863]
   [ 0.04314122 -0.09747379  0.1242681  -0.09130056  0.07854812]
   [-0.09562402 -0.08993824 -0.17737642 -0.08421562 -0.08233961]
   [ 0.01153658 -0.0260659   0.03353431 -0.02510021  0.02159433]
   [-0.02223535 -0.02432176 -0.04074214 -0.02342069 -0.0186144 ]]

  [[ 0.05181031 -0.04226934  0.09133848 -0.03959233  0.04009821]
   [-0.09709334 -0.06940826 -0.13294859 -0.06501248 -0.03934415]
   [-0.00862673 -0.09857383  0.07863871 -0.09262803  0.08118677]
   [-0.02596416 -0.01856077 -0.03623482 -0.01787313 -0.01081643]
   [-0.00601189 -0.02333736  0.01424206 -0.02247276  0.01928911]]

  [[-0.02697872  0.03672051  0.05562605  0.03439492  0.07577281]
   [ 0.02131312  0.03128234  0.13038743  0.02930116  0.10343071]
   [-0.05880928  0.04605546 -0.08453932  0.04339676 -0.02684292]
   [ 0.00569944  0.00836535  0.03501728  0.00805543  0.02843501]
   [-0.01379719  0.00968999 -0.02689682  0.009331   -0.01310653]]]]
dw_numpy_0 
 [[ 0.66199495  0.60207865  0.66608153]
 [ 0.776979    0.96114693  0.73470673]
 [ 0.65992339  0.69515322  0.70807041]]
dw_numpy_1 
 [[ 0.50545913  0.59089005  0.70583433]
 [ 0.44239851  0.61455868  0.67625654]
 [ 0.87371632  0.93543194  1.06095454]]
dw_numpy_2 
 [[ 0.60952866  0.99110438  0.77759622]
 [ 0.61064005  0.88645726  0.8155419 ]
 [ 0.85702622  0.97542119  0.94264254]]
dw_tensor 
 [[[[ 0.66199495  0.60207865  0.66608153]
   [ 0.776979    0.96114693  0.73470673]
   [ 0.65992339  0.69515322  0.70807041]]

  [[ 0.50545913  0.59089005  0.70583433]
   [ 0.44239851  0.61455868  0.67625654]
   [ 0.87371632  0.93543194  1.06095454]]

  [[ 0.60952866  0.99110438  0.77759622]
   [ 0.61064005  0.88645726  0.8155419 ]
   [ 0.85702622  0.97542119  0.94264254]]]]
db_numpy 
 1.46324723111
db_tensor 
 1.46324723111
"""

End of article.
