我正在寻找在theano中使用fft-convolution的方法。
我用theano写了一段简单的卷积代码。
但是,如果我设置"fft_flag = 1",
此代码就不起作用,尽管设置"fft_flag = 0"时普通卷积可以正常运行。
请告诉我这段代码有什么问题?
import numpy as np
import theano.sandbox.cuda.fftconv
from theano.tensor.nnet import conv
import theano.tensor as T
# Minimal script comparing Theano's FFT-based 2-D convolution against the
# conventional conv2d on one random batch.

# Random test batch of shape (100, 76, 76), drawn from [-1, 1) and cast to
# float32.
xdata_test = np.random.uniform(low=-1, high=1, size=(100, 76, 76))
xdata_test = np.asarray(xdata_test, dtype='float32')

# Random filter bank of shape (10, 1, 6, 6), drawn from [-1, 1) and cast to
# float32.
CONVfilter = np.random.uniform(low=-1, high=1, size=(10, 1, 6, 6))
CONVfilter = np.asarray(CONVfilter, dtype='float32')

# Symbolic 3-D input; it is reshaped to the 4-D (100, 1, 76, 76) layout that
# is passed as `image_shape` to both convolution calls below.
x = T.tensor3('x')
layer0_input = x.reshape((100, 1, 76, 76))

# Selects the implementation: 1 -> FFT-based convolution, 0 -> conventional.
fft_flag = 1

if fft_flag == 1:
    # FFT-convolution version.
    # NOTE(review): conv2d_fft lives in the sandbox CUDA backend, so this
    # branch presumably needs a working GPU configuration plus the extra
    # packages listed in the Theano documentation -- confirm on the target
    # machine. The reported TypeError is raised from raise_with_op while
    # Theano re-raises an op failure, so the underlying error may be hidden.
    conv_out = theano.sandbox.cuda.fftconv.conv2d_fft(
        input=layer0_input,
        filters=CONVfilter,
        filter_shape=(10, 1, 6, 6),
        image_shape=(100, 1, 76, 76),
        border_mode='valid',
        pad_last_dim=False
    )
elif fft_flag == 0:
    # Conventional convolution version; no border_mode is passed, so
    # conv2d's default border mode is used.
    conv_out = conv.conv2d(
        input=layer0_input,
        filters=CONVfilter,
        filter_shape=(10, 1, 6, 6),
        image_shape=(100, 1, 76, 76),
    )
else:
    # Fail early with a clear message instead of a later NameError on the
    # unbound `conv_out`.
    raise ValueError("fft_flag must be 0 or 1, got %r" % (fft_flag,))

# Compile the graph into a callable and run it on the random batch.
test_conv = theano.function([x], conv_out)
result = test_conv(xdata_test)
错误信息如下所示:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 580, in runfile
execfile(filename, namespace)
File "C:/Users/user/Documents/Python Scripts/ffttest.py", line 38, in <module>
result = test_conv(xdata_test)
File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 606, in __call__
storage_map=self.fn.storage_map)
File "C:\Anaconda\lib\site-packages\theano\gof\link.py", line 205, in raise_with_op
'\n' + '\n'.join(hints))
TypeError: __init__() takes at least 3 arguments (2 given)