此篇博客为作者对《深度学习入门——基于Python的理论和实现》一书中的im2col函数的笔记
1 学习该函数前需要具备什么知识?
需要了解卷积的相关知识
2 im2col函数的作用(为什么要使用im2co函数)?
卷积运算需要好几层for语句,这样实现麻烦,而且numpy中存在使用for语句变慢的缺点,所以不使用for语句,而用im2col函数代替。
im2col全称image to column(从图像到矩阵),作用为加速卷积运算。即把包含批数量的4维数据转换成2维数据。(也就是将输入数据降维,然后通过numpy的矩阵运算后得到结果,再将结果的形状还原,从而通过用矩阵运算来代替for循环语句)
如图所示,输入数据为四维(N, C, H, W),分别为数据个数, 通道数, 高, 长
滤波器为四维(FN, C, FH, FW),分别为滤波器个数, 通道数, 滤波器高, 滤波器长
输出数据为四维(N, FN, OH, OW),分别为数据个数, 通道数, 高, 长
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
"""
Parameters
----------
input_data : 由(数据量, 通道, 高, 长)的4维数组构成的输入数据
filter_h : 滤波器的高
filter_w : 滤波器的长
stride : 步幅
pad : 填充
Returns
-------
col : 2维数组
"""
N, C, H, W = input_data.shape #数据量、 通道、 高、 宽
#输出数据高宽可由以下公式计算,不懂的同学可以去看卷积的相关知识(就是一个数学公式而已,记住就好)
out_h = (H + 2*pad - filter_h)//stride + 1 #输出数据高 //表示向下取整除法 例:3//2=1
out_w = (W + 2*pad - filter_w)//stride + 1 #输出数据宽
#np.pad教程: https://blog.csdn.net/hustqb/article/details/77726660
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
#初始化一个六维数组 (数据个数, 通道数, 滤波器高, 滤波器宽, 输出高, 输出宽)
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w)) #六维
#循环产生col数组,具体解释见代码下面
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
#transpose详解 https://blog.csdn.net/u012762410/article/details/78912667
#就是0维的不动,第1维的换到第3维,第2维的换到第4维, 第3维的换到第5维,第4维的换到第1维,第5维的换到第2维
col = col.transpose(0, 4, 5, 1, 2, 3)
#reshape变为2维数据,这样就可以不用循环,直接进行矩阵运算了
#-1意思为固定第一维,第二维自动生成 例如:数据共20个,N*out_h*out_w是4,则第一维是4, 第二维是5
col = col.reshape(N*out_h*out_w, -1)
return col
对于此句:
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
我们先来看一个例子:
col = np.zeros((1,2,3))
print(col)
col[:, 1, :] = np.array([[8,88,888]])
print(col)
输出:
[[[0. 0. 0.]
[0. 0. 0.]]]
[[[ 0. 0. 0.]
[ 8. 88. 888.]]]
由此可得出col[:, 1, :]的含义为把[[8, 88, 888]]赋给col[][1][], 也就是把col第二维索引值为1的数组(此例中为[0,0,0]),更改为一个尺寸正好为col第一维和第三维的数组(此例为[8,88,888])
所以如果我们继续执行如下代码:
col[:, 0, :] = np.array([[9,99,999]])
可以预见输出为:
[[[ 9. 99. 999.]
[ 8. 88. 888.]]]
所以回到上面那一句代码:
col数组为六维,img数组为四维,固定col数组的第三维,第四维为y, x, img数组的四维与col数组的第1维,第2维,第5维,第6维相对应。
y:y_max:stride意思可以理解为(y_max-y)/stride,也就等于out_h
所以这段循环实现的功能就是把输入数据按照滤波器的尺寸进行分割。
比如我们如果传入im2col函数的数据为:
尺寸为1,2,4,4(数据个数, 通道数, 高, 宽)
同时把传入im2col函数的滤波器设置为3*3尺寸
所以我们可以得知输出的数据尺寸一定是1*2*2*2(数据个数和通道数不变,高和宽由卷积运算得出)
x1 = np.array([
[[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]],
[[111, 112, 113, 114],
[115, 116, 117, 118],
[119, 120, 121, 122],
[123, 124, 125, 126]]]
])
则下段代码的输出为:
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
print('col',y,x,':\n',col)
col 0 0 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 0 1 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 0 2 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 1 0 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 1 1 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 6. 7.]
[ 10. 11.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[116. 117.]
[120. 121.]]
[[ 0. 0.]
[ 0. 0.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 1 2 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 6. 7.]
[ 10. 11.]]
[[ 7. 8.]
[ 11. 12.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[116. 117.]
[120. 121.]]
[[117. 118.]
[121. 122.]]]
[[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 2 0 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 6. 7.]
[ 10. 11.]]
[[ 7. 8.]
[ 11. 12.]]]
[[[ 9. 10.]
[ 13. 14.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[116. 117.]
[120. 121.]]
[[117. 118.]
[121. 122.]]]
[[[119. 120.]
[123. 124.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 2 1 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 6. 7.]
[ 10. 11.]]
[[ 7. 8.]
[ 11. 12.]]]
[[[ 9. 10.]
[ 13. 14.]]
[[ 10. 11.]
[ 14. 15.]]
[[ 0. 0.]
[ 0. 0.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[116. 117.]
[120. 121.]]
[[117. 118.]
[121. 122.]]]
[[[119. 120.]
[123. 124.]]
[[120. 121.]
[124. 125.]]
[[ 0. 0.]
[ 0. 0.]]]]]]
col 2 2 :
[[[[[[ 1. 2.]
[ 5. 6.]]
[[ 2. 3.]
[ 6. 7.]]
[[ 3. 4.]
[ 7. 8.]]]
[[[ 5. 6.]
[ 9. 10.]]
[[ 6. 7.]
[ 10. 11.]]
[[ 7. 8.]
[ 11. 12.]]]
[[[ 9. 10.]
[ 13. 14.]]
[[ 10. 11.]
[ 14. 15.]]
[[ 11. 12.]
[ 15. 16.]]]]
[[[[111. 112.]
[115. 116.]]
[[112. 113.]
[116. 117.]]
[[113. 114.]
[117. 118.]]]
[[[115. 116.]
[119. 120.]]
[[116. 117.]
[120. 121.]]
[[117. 118.]
[121. 122.]]]
[[[119. 120.]
[123. 124.]]
[[120. 121.]
[124. 125.]]
[[121. 122.]
[125. 126.]]]]]]
由此可见,学习深度学习,最好能提前学习Python和numpy的知识点,不然会很浪费时间!!