I am implementing a custom convolution in torch that uses an FFT to transform the image to the frequency domain, computes the product between the kernel and the image, and then computes the inverse FFT. It works, but I noticed that the product computation is slow. Is there a way to optimize it?
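For context, the identity being relied on is the convolution theorem: elementwise multiplication in the frequency domain corresponds to circular convolution in the spatial domain. A minimal 1-D illustration (numpy is used here instead of torch purely to keep the snippet dependency-light):

```python
import numpy as np

# Convolution theorem: ifft(fft(x) * fft(k)) equals the circular
# convolution of x and k.
x = np.array([1.0, 2.0, 3.0, 4.0])
k = np.array([1.0, 0.0, 0.0, 1.0])

freq_product = np.fft.fft(x) * np.fft.fft(k)
fft_result = np.fft.ifft(freq_product).real

# Direct circular convolution for comparison
n = len(x)
direct = np.array([sum(x[(i - m) % n] * k[m] for m in range(n))
                   for i in range(n)])

print(np.allclose(fft_result, direct))  # True
```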
I wrapped every step in a timer to see where the time goes and got the following results (tested on cpu):
squeezes - done in 4.17232513E-05s
real - done in 1.67846680E-04s
im - done in 7.53402710E-05s
stack - done in 8.36849213E-05s
sum - done in 3.96490097E-04s
bias - done in 1.64508820E-05s
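(The timer helper is not shown in the question; assuming it is a context manager that prints elapsed wall-clock time in the format above, a minimal version might look like this:)

```python
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    # Prints elapsed wall-clock time for the wrapped block,
    # matching the "<name> - done in <seconds>s" format above.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print('%s - done in %.8Es' % (name, elapsed))
```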
Here is my implementation. Note that I do not compute the FFT and its inverse here; torch's implementations of those operations are already very fast.
def fconv2d(input, kernel, bias=None):
    # Computes the convolution in the frequency domain given
    # an input of shape (B, Cin, H, W) and kernel of shape (Cout, Cin, H, W).
    # Expects input and kernel already in frequency domain!
    with timer('squeezes'):
        # Expand kernel to (B, Cout, Cin, H, W)
        kernel = kernel.unsqueeze(0)
        # Expand input to (B, Cout, Cin, H, W)
        input = input.unsqueeze(1)
    # Compute the multiplication
    # (a+bj)*(c+dj) = (ac-bd)+(ad+bc)j
    with timer('real'):
        real = input[..., 0] * kernel[..., 0] - \
               input[..., 1] * kernel[..., 1]
    with timer('im'):
        im = input[..., 0] * kernel[..., 1] + \
             input[..., 1] * kernel[..., 0]
    # Stack both channels and sum-reduce the input channels dimension
    with timer('stack'):
        out = torch.stack([real, im], -1)
    with timer('sum'):
        out = out.sum(dim=-4)
    # Add bias
    with timer('bias'):
        if bias is not None:
            bias = bias.expand(1, 1, 1, bias.shape[0]).permute(0, 3, 1, 2)
            out += bias
    return out
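One direction worth exploring, sketched here under the assumption that the main cost is the four large broadcasted products plus the separate sum: fuse the multiply and the Cin reduction into single contractions with torch.einsum, so the full (B, Cout, Cin, H, W) intermediate is never materialized. This is a sketch, not a drop-in replacement (bias handling is omitted):

```python
import torch

def fconv2d_einsum(input, kernel):
    # input: (B, Cin, H, W, 2), kernel: (Cout, Cin, H, W, 2),
    # with the last dimension holding (real, imag).
    ir, ii = input[..., 0], input[..., 1]
    kr, ki = kernel[..., 0], kernel[..., 1]
    # (a+bj)*(c+dj) = (ac-bd)+(ad+bc)j, contracted over Cin in one call:
    real = torch.einsum('bihw,oihw->bohw', ir, kr) - \
           torch.einsum('bihw,oihw->bohw', ii, ki)
    im = torch.einsum('bihw,oihw->bohw', ir, ki) + \
         torch.einsum('bihw,oihw->bohw', ii, kr)
    return torch.stack([real, im], -1)
```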