
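The nine-step process can be viewed as a single matrix multiplication between two matrices. At each step, flatten the input region: for example, the first dark-blue region in the top-left corner becomes the row vector $(3\ 3\ 2\ 0\ 0\ 1\ 3\ 1\ 2)$. Then flatten the Kernel into the column vector $(0\ 1\ 2\ 2\ 2\ 0\ 0\ 1\ 2)^T$. Multiplying the two gives 12.0, the top-left value of the final result.
So the nine steps above amount to multiplying a 9×9 matrix (one flattened region per row, one row per step) by a 9×1 matrix (the flattened Kernel), and then reshaping the product into the desired output.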
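As a quick sanity check of the first step, here is a minimal sketch in plain Python that uses only the two flattened vectors quoted above. The variable names are illustrative, not from the original code.

```python
# Flattened top-left 3x3 region and flattened 3x3 kernel, values taken from the text.
region_vector = [3, 3, 2, 0, 0, 1, 3, 1, 2]
kernel_vector = [0, 1, 2, 2, 2, 0, 0, 1, 2]

# A row vector times a column vector is the sum of element-wise products.
top_left = sum(r * k for r, k in zip(region_vector, kernel_vector))
print(float(top_left))  # 12.0
```

Stacking the nine flattened regions as rows and repeating this product for each row gives the nine output values in one matrix multiplication.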
Other Method
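Compute each output value as an inner product of length 25.
The Kernel currently covers only a 3×3 window. Suppose we pad it. In the first figure (top-left), for example, the Kernel has only 9 values, but we can imagine every light-blue cell filled with 0. If at every step the light-blue cells around the Kernel are filled with 0, the original problem becomes the product of a length-25 row vector and a length-25 column vector.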
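The zero-padded-kernel idea can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the top-left 3×3 block of `inp` and the kernel values come from the text, while the remaining input entries are assumed values chosen for the example, and `padded_kernel_vector` is a hypothetical helper name.

```python
# 5x5 input: the top-left 3x3 block matches the text; the other entries are assumed.
inp = [
    [3, 3, 2, 1, 0],
    [0, 0, 1, 3, 1],
    [3, 1, 2, 2, 3],
    [2, 0, 0, 2, 2],
    [2, 0, 0, 0, 1],
]
kernel = [
    [0, 1, 2],
    [2, 2, 0],
    [0, 1, 2],
]

def padded_kernel_vector(kernel, top, left, height, width):
    """Place the kernel at (top, left) in a height x width grid of zeros, then flatten."""
    grid = [[0] * width for _ in range(height)]
    for r, row in enumerate(kernel):
        for c, value in enumerate(row):
            grid[top + r][left + c] = value
    return [v for row in grid for v in row]

input_vector = [v for row in inp for v in row]  # length 25

output = []
for top in range(3):       # 5 - 3 + 1 = 3 window positions per axis
    out_row = []
    for left in range(3):
        k25 = padded_kernel_vector(kernel, top, left, 5, 5)  # length 25
        out_row.append(sum(x * k for x, k in zip(input_vector, k25)))
    output.append(out_row)

print(output[0][0])  # 12
```

The input vector stays fixed and only the position of the kernel inside the zero grid changes from step to step, so each output value is one length-25 inner product.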
Coding
The goal is to collect every region_vector into region_matrix, and then multiply region_matrix by kernel_matrix.
import math
import torch
import torch.nn.functional as F

input = torch.randn(5, 5)
kernel = torch.randn(3, 3)
bias = torch.randn(1)

def matrix_mutiplication_for_conv2d_flatten(input, kernel, bias=0, stride=1, padding=0):
    # Zero-pad all four sides of the 2D input
    if padding > 0:
        input = F.pad(input, (padding, padding, padding, padding))
    input_h, input_w = input.shape
    kernel_h, kernel_w = kernel.shape
    output_w = (math.floor((input_w - kernel_w) / stride) + 1)
    output_h = (math.floor((input_h - kernel_h) / stride) + 1)
    output = torch.zeros(output_h, output_w)
    # One row per window position, one column per kernel element
    region_matrix = torch.zeros(output.numel(), kernel.numel())
    kernel_matrix = kernel.reshape((kernel.numel(), 1))
    roll_index = 0
    for i in range(0, input_h - kernel_h + 1, stride):
        for j in range(0, input_w - kernel_w + 1, stride):
            region = input[i:i+kernel_h, j:j+kernel_w]
            region_vector = torch.flatten(region)  # flatten the window into a row
            region_matrix[roll_index] = region_vector
            roll_index += 1
    # (num_windows, k*k) @ (k*k, 1) -> (num_windows, 1)
    output_matrix = region_matrix @ kernel_matrix
    output = output_matrix.reshape((output_h, output_w)) + bias
    return output

mat_mul_conv_flatten_output = matrix_mutiplication_for_conv2d_flatten(input, kernel, bias=bias, padding=1)
# Reference result from the PyTorch API (assumed to be computed as below, with the same padding and bias)
pytorch_api_conv_output = F.conv2d(input.reshape(1, 1, 5, 5), kernel.reshape(1, 1, 3, 3), bias=bias, padding=1)
print(mat_mul_conv_flatten_output.shape, "\n", mat_mul_conv_flatten_output)
print(pytorch_api_conv_output.squeeze(0).squeeze(0).shape, "\n", pytorch_api_conv_output.squeeze(0).squeeze(0))
'''
torch.Size([5, 5])
tensor([[-0.2212, -5.2747, -2.2503, -5.2114, -3.0707],
[ 0.4797, -0.0270, 4.1195, 4.6948, 0.7474],
[ 0.4497, 0.1129, -1.5966, -6.6838, -3.4131],
[-1.9987, -1.4590, 1.0481, 1.0733, 2.2200],
[ 0.4015, 2.9796, -4.8810, -3.2139, -2.6157]])
torch.Size([5, 5])
tensor([[-0.2212, -5.2747, -2.2503, -5.2114, -3.0707],
[ 0.4797, -0.0270, 4.1195, 4.6948, 0.7474],
[ 0.4497, 0.1129, -1.5966, -6.6838, -3.4131],
[-1.9987, -1.4590, 1.0481, 1.0733, 2.2200],
[ 0.4015, 2.9796, -4.8810, -3.2139, -2.6157]])
'''
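Both printed shapes are [5, 5] because padding=1 grows the 5×5 input to 7×7 before the 3×3 Kernel slides over it. A small sketch of the size arithmetic used in the function above, where `conv_output_size` is a hypothetical helper name introduced only for this illustration:

```python
import math

def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    """Output length along one axis: floor((input + 2*padding - kernel) / stride) + 1."""
    return math.floor((input_size + 2 * padding - kernel_size) / stride) + 1

print(conv_output_size(5, 3, stride=1, padding=1))  # 5
print(conv_output_size(5, 3))                       # 3
```

Without padding the same input yields a 3×3 output, which is the nine-step case described at the start.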