im2col操作是用来优化卷积运算,它的核心是将卷积核感受野的转化成一行(列)来存储,优化运算速度,减少内存访问时间。
caffe等框架中使用了这种计算方式(im2col + gemm矩阵运算)来优化卷积计算:
示意图如下所示:
将滑动窗口事先转变为一个矩阵
深度学习的应用
批处理中的维度约定:
输入数据中,n=3,c=3(同时处理三个 三通道的图像数据)
卷积核: n=3,c=3
输出:n=3,c=1
如果是四维数据,则形状为(10,3,28,28)代表了10个三通道的宽高为28的图像数据合并成一个tensor
im2col 操作示意图:
对于输入图像,首先将图像转化为一个 巨大的矩阵:
对于卷积核,将卷积核展开处理:
将输入图像得到的矩阵与卷积核展开后的列做矩阵乘法(可以大批量加速),得到结果后进行col2im操作复原结果:
im2col 代码展示:
思路:将卷积核的感受野部分放置到一行
def im2col(input_data,ksize,stride=1,pad=0):
"""
----------
input_data : 由(数据量, 通道, 高, 长)的4维数组构成的输入数据
ksize :卷积核大小
stride : 步幅
pad : 填充
Returns
-------
col : 2维数组
"""
N,C,H,W = input_data.shape
out_h = (H + 2 * pad - ksize)//stride + 1
out_w = (W + 2 * pad - ksize)//stride + 1
#填充padiing
#只填充H,W维度的
img = np.pad(input_data,[(0,0), (0,0), (pad, pad), (pad, pad)],'constant')
##最后的shape为 out_h,out_w 表示方便取同一行的感受野,然后reshape的时候不同列变为同一行
col = np.zeros((N,C,ksize,ksize,out_h,out_w))
for y in range(ksize):
y_max = y + stride * out_h
for x in range(ksize):
##一次性取out_w *stride 个数放入新矩阵中
x_max = x + stride * out_w
col[:,:,y,x,:,:]=img[:,:,y:y_max,x:x_max]
#transpose 操作很关键
#原来shape是(N,C,filter_h,filter_w,out_h,out_w)
#现在保持N,out_w,out_h不变,将C,filter_w,filter_h的值放都后面维度,相当于把卷积感受野部分全放后面维度
#然后reshape 操作把channel,filter_h,filter_w(卷积感受野部分) 规整成一行,方便直接与卷积做矩阵乘法
col = col.transpose(0,4,5,1,2,3).reshape(N*out_h*out_w, -1)
return col
例子验证:
假如我们有一个1-16的4*4矩阵,padding=1,
,
则卷积核为3*3的卷积,则,它每次滑动过的感受野为:
我们将其依次拼接,得到矩阵:
每一行为一个感受野;
>>> A = np.arange(1,17).reshape(4,4)
>>> input_img = A.reshape(1,1,4,4)
>>> col = im2col(input_img,3,stride=1,pad=1)
>>> col
array([[ 0., 0., 0., 0., 1., 2., 0., 5., 6.],
[ 0., 0., 0., 1., 2., 3., 5., 6., 7.],
[ 0., 0., 0., 2., 3., 4., 6., 7., 8.],
[ 0., 0., 0., 3., 4., 0., 7., 8., 0.],
[ 0., 1., 2., 0., 5., 6., 0., 9., 10.],
[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
[ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
[ 3., 4., 0., 7., 8., 0., 11., 12., 0.],
[ 0., 5., 6., 0., 9., 10., 0., 13., 14.],
[ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
[ 6., 7., 8., 10., 11., 12., 14., 15., 16.],
[ 7., 8., 0., 11., 12., 0., 15., 16., 0.],
[ 0., 9., 10., 0., 13., 14., 0., 0., 0.],
[ 9., 10., 11., 13., 14., 15., 0., 0., 0.],
[10., 11., 12., 14., 15., 16., 0., 0., 0.],
[11., 12., 0., 15., 16., 0., 0., 0., 0.]])