我最近被介绍给PyTorch,并开始浏览该库的文档和教程。 在“使用numpy和scipy创建扩展”教程中( http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html),在“无参数示例”下,使用名为“BadFFTFunction”的numpy创建示例函数。
功能说明:
“这一层并没有特别有用或数学上做任何事情 正确的。
它被恰当地命名为BadFFTFunction“
该功能及其用法如下:
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
def forward(self, input):
numpy_input = input.numpy()
result = abs(rfft2(numpy_input))
return torch.FloatTensor(result)
def backward(self, grad_output):
numpy_go = grad_output.numpy()
result = irfft2(numpy_go)
return torch.FloatTensor(result)
def incorrect_fft(input):
return BadFFTFunction()(input)
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)
不幸的是,我最近才被引入信号处理,并且不确定此函数中可能存在的(可能是明显的)错误。
我想知道,如何修复此功能以使其前向和后向输出正确? 如何修复BadFFTFunction,以便在PyTorch中使用可微分的FFT函数?
任何帮助将不胜感激。 谢谢。
答案 0 :(得分:1)
我认为错误是:首先,尽管名称中包含FFT,但该函数仅返回FFT输出的幅度/绝对值,而不是完整的复系数。另外,仅使用逆FFT计算幅度的梯度可能在数学上没有多大意义(?)。
有一个名为pytorch-fft的软件包试图在pytorch中提供FFT函数。您可以看到一些autograd功能的实验代码here。另请注意此issue中的讨论。
答案 1 :(得分:0)
从 1.8 版开始,PyTorch 具有 torch.fft
:
torch.fft.fft(input)