Faster option for computing the product of a frequency-domain convolution in PyTorch

Asked: 2019-03-18 12:26:48

标签: performance profiling fft pytorch convolution

I am implementing a custom convolution in PyTorch that uses the FFT to transform the image to the frequency domain, computes the product between the kernel and the image, and then applies the inverse FFT. It works, but I noticed that computing the product is slow. Is there a way to optimize it?

I added a timer around each step to see where the time goes and got the following results (tested on CPU):

squeezes - done in 4.17232513E-05s
real - done in 1.67846680E-04s
im - done in 7.53402710E-05s
stack - done in 8.36849213E-05s
sum - done in 3.96490097E-04s
bias - done in 1.64508820E-05s

Here is my implementation:

Note that I do not time the FFT and its inverse here; torch's implementations of those operations are already very fast.
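The `timer` context manager used below isn't shown in the question; judging from the printed output, it is presumably something like this minimal sketch:

```python
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    # Print the wall-clock time spent inside the `with` block,
    # formatted like the measurements quoted above.
    start = time.perf_counter()
    yield
    print(f"{name} - done in {time.perf_counter() - start:.8E}s")
```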

def fconv2d(input, kernel, bias=None):
    # Computes the convolution in the frequency domain given
    # an input of shape (B, Cin, H, W, 2) and a kernel of shape
    # (Cout, Cin, H, W, 2), where the last dim holds (real, imag).
    # Expects input and kernel already in the frequency domain!

    with timer('squeezes'):
        # Unsqueeze kernel to (1, Cout, Cin, H, W, 2)
        kernel = kernel.unsqueeze(0)
        # Unsqueeze input to (B, 1, Cin, H, W, 2); both broadcast
        # to (B, Cout, Cin, H, W, 2) in the products below
        input = input.unsqueeze(1)
    # Compute the multiplication
    # (a+bj)*(c+dj) = (ac-bd)+(ad+bc)j
    with timer('real'):
        real = input[..., 0] * kernel[..., 0] - \
               input[..., 1] * kernel[..., 1]
    with timer('im'):
        im = input[..., 0] * kernel[..., 1] + \
             input[..., 1] * kernel[..., 0]
    # Stack both channels and sum-reduce the input channels dimension
    with timer('stack'):
        out = torch.stack([real, im], -1)

    with timer('sum'):
        out = out.sum(dim=-4)
    # Add bias to the real part, broadcast over (B, Cout, H, W)
    with timer('bias'):
        if bias is not None:
            out[..., 0] += bias.view(1, -1, 1, 1)
    return out
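Not an answer from the thread (there are none), but for illustration: since the slow parts are the four elementwise products and the sum over `Cin`, both can be fused into contractions with `torch.einsum`, which avoids materializing the large `(B, Cout, Cin, H, W, 2)` intermediate entirely. A sketch under the same shape conventions; `fconv2d_einsum` is a hypothetical name, and any speedup would need to be measured:

```python
import torch

def fconv2d_einsum(input, kernel, bias=None):
    # Same contract as fconv2d: input of shape (B, Cin, H, W, 2),
    # kernel of shape (Cout, Cin, H, W, 2), both already in the
    # frequency domain, last dim holding (real, imag).
    ir, ii = input[..., 0], input[..., 1]    # each (B, Cin, H, W)
    kr, ki = kernel[..., 0], kernel[..., 1]  # each (Cout, Cin, H, W)
    # (a+bj)*(c+dj) = (ac-bd) + (ad+bc)j, with the sum over Cin
    # folded into each contraction ('c' is the contracted index)
    real = torch.einsum('bchw,ochw->bohw', ir, kr) - \
           torch.einsum('bchw,ochw->bohw', ii, ki)
    im   = torch.einsum('bchw,ochw->bohw', ir, ki) + \
           torch.einsum('bchw,ochw->bohw', ii, kr)
    out = torch.stack([real, im], dim=-1)    # (B, Cout, H, W, 2)
    if bias is not None:
        # Bias on the real part, broadcast over (B, Cout, H, W)
        out[..., 0] += bias.view(1, -1, 1, 1)
    return out
```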

0 Answers:

There are no answers yet.