如何在PyTorch中正确使用Numpy的FFT函数?

时间:2017-08-17 23:46:17

标签: python numpy machine-learning fft pytorch

我最近被介绍给PyTorch,并开始浏览该库的文档和教程。 在“使用numpy和scipy创建扩展”教程中( http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html),在“无参数示例”下,使用名为“BadFFTFunction”的numpy创建示例函数。

功能说明:

  

“这一层并没有特别有用或数学上做任何事情   正确的。

     

它被恰当地命名为BadFFTFunction“

该功能及其用法如下:

from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    def forward(self, input):
        numpy_input = input.numpy()
        result = abs(rfft2(numpy_input))
        return torch.FloatTensor(result)

    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return torch.FloatTensor(result)

def incorrect_fft(input):
    return BadFFTFunction()(input)

input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

不幸的是,我最近才被引入信号处理,并且不确定此函数中可能存在的(可能是明显的)错误。

我想知道,如何修复此功能以使其前向和后向输出正确? 如何修复BadFFTFunction,以便在PyTorch中使用可微分的FFT函数?

任何帮助将不胜感激。 谢谢。

2 个答案:

答案 0 :(得分:1)

我认为错误是:首先,尽管名称中包含FFT,但该函数仅返回FFT输出的幅度/绝对值,而不是完整的复系数。另外,仅使用逆FFT计算幅度的梯度可能在数学上没有多大意义(?)。

有一个名为pytorch-fft的软件包试图在pytorch中提供FFT函数。您可以看到一些autograd功能的实验代码here。另请注意此issue中的讨论。

答案 1 :(得分:0)

从 1.8 版开始,PyTorch 具有 torch.fft

torch.fft.fft(input)