本文分别使用纯 Python（基于 NumPy）和 PyTorch 实现卷积（convolution）函数及其反向传播，并对比两者的计算结果。
原理和详细推导请参考文章：
卷积convolution函数详解及反向传播中的梯度求导
系列文章索引：
https://blog.csdn.net/oBrightLamp/article/details/85067981
import torch
import numpy as np
class Conv2d:
    """Single-channel 2-D convolution without padding, implemented in NumPy.

    Like ``torch.nn.Conv2d`` this is a cross-correlation: the kernel is not
    flipped.  ``weight`` (a 2-D kernel) and ``bias`` (a scalar, or ``None``
    for no bias) are assigned by the caller after construction.  Calling the
    instance runs the forward pass and caches the input; ``backward`` then
    returns the input gradient and stores the parameter gradients in
    ``dw`` / ``db``.
    """

    def __init__(self, stride=1):
        self.weight = None
        self.bias = None
        self.stride = stride
        # State filled in by the forward pass and reused by backward().
        self.x = None
        self.input_height = None
        self.input_width = None
        self.weight_height = None
        self.weight_width = None
        self.output_height = None
        self.output_width = None
        # Parameter gradients filled in by backward().
        self.dw = None
        self.db = None

    def __call__(self, x):
        """Run the forward pass on a 2-D input ``x`` of shape (H, W).

        Returns:
            Array of shape (out_h, out_w), where
            ``out_h = max((H - kernel_h) // stride + 1, 0)`` (likewise for the
            width).  ``out[i, j]`` is the sum of the element-wise product of
            the kernel with the window of ``x`` starting at
            ``(i * stride, j * stride)``, plus ``bias``.
        """
        x = np.asarray(x)
        weight = np.asarray(self.weight)
        self.x = x
        self.input_height, self.input_width = x.shape[0], x.shape[1]
        self.weight_height, self.weight_width = weight.shape[0], weight.shape[1]
        # Floor division clamped at 0.  int((H - kh) / stride) would truncate
        # toward zero, so an input smaller than the kernel with stride > 1
        # would get a phantom output row/column and an IndexError below.
        self.output_height = max(
            (self.input_height - self.weight_height) // self.stride + 1, 0
        )
        self.output_width = max(
            (self.input_width - self.weight_width) // self.stride + 1, 0
        )

        kh, kw = self.weight_height, self.weight_width
        out = np.zeros((self.output_height, self.output_width))
        for i in range(self.output_height):
            top = i * self.stride
            for j in range(self.output_width):
                left = j * self.stride
                out[i, j] = np.sum(x[top:top + kh, left:left + kw] * weight)
        # A bias of None (the constructor default) means "no bias".
        bias = 0.0 if self.bias is None else self.bias
        return out + bias

    def backward(self, d_loss):
        """Back-propagate ``d_loss`` through the most recent forward pass.

        Args:
            d_loss: upstream gradient with the shape of the forward output.

        Returns:
            Gradient w.r.t. the forward input ``x`` (same shape as ``x``).
            Also sets ``self.dw`` (same shape as ``weight``) and ``self.db``
            (the sum of ``d_loss``).
        """
        d_loss = np.asarray(d_loss)
        weight = np.asarray(self.weight)
        # Accumulate in a floating dtype: zeros_like(x) would inherit an
        # integer dtype from integer inputs, making the float += below raise
        # (slice update of dx) or silently truncate (element update of dw).
        grad_dtype = np.result_type(self.x, weight, d_loss, 1.0)
        dx = np.zeros(np.shape(self.x), dtype=grad_dtype)
        self.dw = np.zeros(np.shape(weight), dtype=grad_dtype)

        kh, kw = self.weight_height, self.weight_width
        for i in range(self.output_height):
            top = i * self.stride
            for j in range(self.output_width):
                left = j * self.stride
                grad = d_loss[i, j]
                # Each output element touched a kh x kw window of x, weighted
                # by the kernel: scatter its gradient back to both.
                dx[top:top + kh, left:left + kw] += grad * weight
                self.dw += grad * self.x[top:top + kh, left:left + kw]
        self.db = np.sum(d_loss)
        return dx
np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)
torch.random.manual_seed(123)

# Input: batch of 1, 3 channels, 5x5.  Reference layer: 3 input channels ->
# 1 output channel, 3x3 kernel, stride 2, in float64 to match NumPy.
x_numpy = np.random.random((1, 3, 5, 5))
x_tensor = torch.tensor(x_numpy, requires_grad=True)
conv2d_tensor = torch.nn.Conv2d(3, 1, (3, 3), stride=2).double()

in_channels = x_numpy.shape[1]
weight_numpy = conv2d_tensor.weight.data.numpy()
bias_numpy = conv2d_tensor.bias.data.numpy()[0]

# A multi-channel convolution with a single output channel equals the sum of
# one single-channel convolution per input channel, plus the bias once.  The
# per-channel layers therefore get a zero bias, so the bias is not counted
# once per channel (this works for any channel count).
conv2d_numpy_channels = []
for c in range(in_channels):
    layer = Conv2d(stride=2)
    layer.weight = weight_numpy[0, c]
    layer.bias = 0.0
    conv2d_numpy_channels.append(layer)

out_numpy = sum(
    layer(x_numpy[0, c]) for c, layer in enumerate(conv2d_numpy_channels)
) + bias_numpy
out_tensor = conv2d_tensor(x_tensor)

d_loss_numpy = np.random.random(out_tensor.shape)
d_loss_tensor = torch.tensor(d_loss_numpy)

# Every input channel receives the gradient of the single output channel.
dx_numpy = [layer.backward(d_loss_numpy[0, 0]) for layer in conv2d_numpy_channels]
out_tensor.backward(d_loss_tensor)
dx_tensor = x_tensor.grad
dw_numpy = [layer.dw for layer in conv2d_numpy_channels]
dw_tensor = conv2d_tensor.weight.grad
db_numpy = conv2d_numpy_channels[0].db
db_tensor = conv2d_tensor.bias.grad

print("out_numpy \n", out_numpy)
print("out_tensor \n", out_tensor.data.numpy())
for c, dx in enumerate(dx_numpy):
    print(f"dx_numpy_{c} \n", dx)
print("dx_tensor \n", dx_tensor.data.numpy())
for c, dw in enumerate(dw_numpy):
    print(f"dw_numpy_{c} \n", dw)
print("dw_tensor \n", dw_tensor.data.numpy())
print("db_numpy \n", db_numpy)
print("db_tensor \n", db_tensor.data.numpy()[0])

# Verify agreement programmatically rather than only by eye.
checks = {
    "out": np.allclose(out_numpy, out_tensor.data.numpy()[0, 0]),
    "dx": np.allclose(np.stack(dx_numpy), dx_tensor.data.numpy()[0]),
    "dw": np.allclose(np.stack(dw_numpy), dw_tensor.data.numpy()[0]),
    "db": np.isclose(db_numpy, db_tensor.data.numpy()[0]),
}
mismatched = [name for name, ok in checks.items() if not ok]
if mismatched:
    raise AssertionError(f"NumPy and PyTorch results disagree on: {mismatched}")
"""
代码输出 :
out_numpy
[[-0.01129476 -0.12804462]
[-0.04662733 -0.12602237]]
out_tensor
[[[[-0.01129476 -0.12804462]
[-0.04662733 -0.12602237]]]]
dx_numpy_0
[[-0.04664896 0.0037894 -0.10051156 0.00354941 -0.05321863]
[ 0.04314122 -0.09747379 0.1242681 -0.09130056 0.07854812]
[-0.09562402 -0.08993824 -0.17737642 -0.08421562 -0.08233961]
[ 0.01153658 -0.0260659 0.03353431 -0.02510021 0.02159433]
[-0.02223535 -0.02432176 -0.04074214 -0.02342069 -0.0186144 ]]
dx_numpy_1
[[ 0.05181031 -0.04226934 0.09133848 -0.03959233 0.04009821]
[-0.09709334 -0.06940826 -0.13294859 -0.06501248 -0.03934415]
[-0.00862673 -0.09857383 0.07863871 -0.09262803 0.08118677]
[-0.02596416 -0.01856077 -0.03623482 -0.01787313 -0.01081643]
[-0.00601189 -0.02333736 0.01424206 -0.02247276 0.01928911]]
dx_numpy_2
[[-0.02697872 0.03672051 0.05562605 0.03439492 0.07577281]
[ 0.02131312 0.03128234 0.13038743 0.02930116 0.10343071]
[-0.05880928 0.04605546 -0.08453932 0.04339676 -0.02684292]
[ 0.00569944 0.00836535 0.03501728 0.00805543 0.02843501]
[-0.01379719 0.00968999 -0.02689682 0.009331 -0.01310653]]
dx_tensor
[[[[-0.04664896 0.0037894 -0.10051156 0.00354941 -0.05321863]
[ 0.04314122 -0.09747379 0.1242681 -0.09130056 0.07854812]
[-0.09562402 -0.08993824 -0.17737642 -0.08421562 -0.08233961]
[ 0.01153658 -0.0260659 0.03353431 -0.02510021 0.02159433]
[-0.02223535 -0.02432176 -0.04074214 -0.02342069 -0.0186144 ]]
[[ 0.05181031 -0.04226934 0.09133848 -0.03959233 0.04009821]
[-0.09709334 -0.06940826 -0.13294859 -0.06501248 -0.03934415]
[-0.00862673 -0.09857383 0.07863871 -0.09262803 0.08118677]
[-0.02596416 -0.01856077 -0.03623482 -0.01787313 -0.01081643]
[-0.00601189 -0.02333736 0.01424206 -0.02247276 0.01928911]]
[[-0.02697872 0.03672051 0.05562605 0.03439492 0.07577281]
[ 0.02131312 0.03128234 0.13038743 0.02930116 0.10343071]
[-0.05880928 0.04605546 -0.08453932 0.04339676 -0.02684292]
[ 0.00569944 0.00836535 0.03501728 0.00805543 0.02843501]
[-0.01379719 0.00968999 -0.02689682 0.009331 -0.01310653]]]]
dw_numpy_0
[[ 0.66199495 0.60207865 0.66608153]
[ 0.776979 0.96114693 0.73470673]
[ 0.65992339 0.69515322 0.70807041]]
dw_numpy_1
[[ 0.50545913 0.59089005 0.70583433]
[ 0.44239851 0.61455868 0.67625654]
[ 0.87371632 0.93543194 1.06095454]]
dw_numpy_2
[[ 0.60952866 0.99110438 0.77759622]
[ 0.61064005 0.88645726 0.8155419 ]
[ 0.85702622 0.97542119 0.94264254]]
dw_tensor
[[[[ 0.66199495 0.60207865 0.66608153]
[ 0.776979 0.96114693 0.73470673]
[ 0.65992339 0.69515322 0.70807041]]
[[ 0.50545913 0.59089005 0.70583433]
[ 0.44239851 0.61455868 0.67625654]
[ 0.87371632 0.93543194 1.06095454]]
[[ 0.60952866 0.99110438 0.77759622]
[ 0.61064005 0.88645726 0.8155419 ]
[ 0.85702622 0.97542119 0.94264254]]]]
db_numpy
1.46324723111
db_tensor
1.46324723111
"""
全文完