该9步过程可以看作是一个矩阵
跟另外一个矩阵
的矩阵相乘
将每步input
拉直,例如左上角第一个深蓝色区域拉为332001312
行向量,再将Kernel
拉为$(012220012)^T$列向量相乘得到第一个最终结果的左上角第一个数值12.0
So,上边九步运算可以视为行数为9的矩阵和列数为9的矩阵进行矩阵乘法,再将相乘的结果reshape
为欲得到的结果
Other Method
实现一个长度为25
的内积
目前Kernel
只是3×3
的范围大小,若是将Kernel
填充一下,eg.
左上角第一幅图,Kernel
只有9
个数,但是我们可以脑补一下浅蓝色部分都填充为0
,每一步都将Kernel
填充浅蓝色部分为0
,那么,原问题就变为25
行向量和25
列向量相乘。
Coding
希望把region_vector都放入到region_matrix中,再将region_matrix与Kernel_matrix相乘
input = torch.randn(5, 5) # 卷积输入特征图
kernel = torch.randn(3, 3) # 卷积核
bias = torch.randn(1) # 卷积偏置项,默认输出通道数目=1
# Func2 input flatten 版本
def matrix_mutiplication_for_conv2d_flatten(input, kernel, bias=0, stride=1, padding=0):
if padding > 0:
input = F.pad(input, (padding, padding, padding, padding)) # 左右上下都pad0
input_h, input_w = input.shape
kernel_h, kernel_w = kernel.shape
output_w = (math.floor((input_w - kernel_w) / stride) + 1) # 卷积输出的宽度
output_h = (math.floor((input_h - kernel_h) / stride) + 1) # 卷积输出的高度
output = torch.zeros(output_h, output_w) # 初始化输出矩阵
# numel()返回tensor中元素数目
# 上图所示一共9行9列
region_matrix = torch.zeros(output.numel(), kernel.numel()) # 定义region_matrix 存储所有展平后的特征区域
kernel_matrix = kernel.reshape((kernel.numel(), 1)) # 定义kernel_matrix 存储kernel的列向量形式
roll_index = 0
for i in range(0, input_h - kernel_h + 1, stride): # 对高度维度进行遍历
for j in range(0, input_w - kernel_w + 1, stride): # 对宽度维度进行遍历
region = input[i:i+kernel_h, j:j+kernel_w] # 取出被核滑动到的区域 3*3=9
region_vector = torch.flatten(region) # 将该区域展平 3*3=>1*9
region_matrix[roll_index] = region_vector # 每次放入一行1*9,9次将region_matrix全部铺满
roll_index += 1 # 下次再下一行赋值
output_matrix = region_matrix @ kernel_matrix # 两个矩阵相乘,就是对9次的input滑动窗口中的3*3展平后的1*9行向量✖kernel的列向量相乘9次加起来
output = output_matrix.reshape((output_h, output_w)) + bias
return output
# 矩阵运算实现卷积的结果 flatten版本
mat_mul_conv_flatten_output = matrix_mutiplication_for_conv2d_flatten(input, kernel, bias=bias, padding=1)
print(mat_mul_conv_flatten_output.shape, "\n", mat_mul_conv_flatten_output)
print(pytorch_api_conv_output.squeeze(0).squeeze(0).shape, "\n", pytorch_api_conv_output.squeeze(0).squeeze(0))
'''
torch.Size([5, 5])
tensor([[-0.2212, -5.2747, -2.2503, -5.2114, -3.0707],
[ 0.4797, -0.0270, 4.1195, 4.6948, 0.7474],
[ 0.4497, 0.1129, -1.5966, -6.6838, -3.4131],
[-1.9987, -1.4590, 1.0481, 1.0733, 2.2200],
[ 0.4015, 2.9796, -4.8810, -3.2139, -2.6157]])
torch.Size([5, 5])
tensor([[-0.2212, -5.2747, -2.2503, -5.2114, -3.0707],
[ 0.4797, -0.0270, 4.1195, 4.6948, 0.7474],
[ 0.4497, 0.1129, -1.5966, -6.6838, -3.4131],
[-1.9987, -1.4590, 1.0481, 1.0733, 2.2200],
[ 0.4015, 2.9796, -4.8810, -3.2139, -2.6157]])
'''
还不快抢沙发