Can you fix this Theano code for FFT convolution?

Time: 2015-09-09 05:15:34

Tags: fft convolution theano

I am looking for a way to use FFT convolution in Theano. I wrote a simple convolution script with Theano. The plain convolution works when I set "fft_flag = 0", but the code fails when I set "fft_flag = 1".

Could you tell me what is wrong with this code?

import numpy as np
import theano
import theano.sandbox.cuda.fftconv
import theano.tensor as T
from theano.tensor.nnet import conv

# Random test data: 100 images of 76x76 and 10 filters of 6x6, both float32
xdata_test = np.random.uniform(low=-1, high=1, size=(100, 76, 76))
xdata_test = np.asarray(xdata_test, dtype='float32')
CONVfilter = np.random.uniform(low=-1, high=1, size=(10, 1, 6, 6))
CONVfilter = np.asarray(CONVfilter, dtype='float32')

x = T.tensor3('x')  # batch of images, shape (100, 76, 76)
layer0_input = x.reshape((100, 1, 76, 76))  # add the channel axis

fft_flag = 1
if fft_flag == 1:
    # FFT convolution version
    conv_out = theano.sandbox.cuda.fftconv.conv2d_fft(
        input=layer0_input,
        filters=CONVfilter,
        filter_shape=(10, 1, 6, 6),
        image_shape=(100, 1, 76, 76),
        border_mode='valid',
        pad_last_dim=False
    )
elif fft_flag == 0:
    # Conventional convolution version
    conv_out = conv.conv2d(
        input=layer0_input,
        filters=CONVfilter,
        filter_shape=(10, 1, 6, 6),
        image_shape=(100, 1, 76, 76),
    )

test_conv = theano.function([x], conv_out)
result = test_conv(xdata_test)

The error message is as follows:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 580, in runfile
    execfile(filename, namespace)
  File "C:/Users/user/Documents/Python Scripts/ffttest.py", line 38, in <module>
    result = test_conv(xdata_test)
  File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 606, in __call__
    storage_map=self.fn.storage_map)
  File "C:\Anaconda\lib\site-packages\theano\gof\link.py", line 205, in raise_with_op
    '\n' + '\n'.join(hints))
TypeError: __init__() takes at least 3 arguments (2 given)     
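For reference while debugging, the values that both branches should produce can be checked on the CPU with NumPy alone. The following is a minimal sketch, not Theano code, and it does not diagnose the TypeError above. It assumes that a 'valid' convolution here means a true convolution (filters flipped) summed over input channels, which is my understanding of what conv.conv2d computes. The function names conv2d_fft_valid and conv2d_direct_valid are hypothetical helpers made up for this illustration.

```python
import numpy as np


def conv2d_fft_valid(images, filters):
    """'valid' 2-D convolution via FFT (CPU, NumPy only).

    images:  (batch, channels, H, W)
    filters: (n_filters, channels, kh, kw)
    returns: (batch, n_filters, H - kh + 1, W - kw + 1)
    """
    b, c, h, w = images.shape
    n, c2, kh, kw = filters.shape
    assert c == c2, "channel counts must match"
    # Zero-pad both operands to the size of the full linear convolution
    fh, fw = h + kh - 1, w + kw - 1
    f_img = np.fft.rfft2(images, s=(fh, fw))
    f_flt = np.fft.rfft2(filters, s=(fh, fw))
    # Pointwise product in the frequency domain, summed over channels
    prod = np.einsum('bchw,nchw->bnhw', f_img, f_flt)
    full = np.fft.irfft2(prod, s=(fh, fw))
    # Keep only positions where the filter fully overlaps the image
    return full[:, :, kh - 1:h, kw - 1:w]


def conv2d_direct_valid(images, filters):
    """Slow reference: direct 'valid' convolution with flipped filters."""
    b, c, h, w = images.shape
    n, _, kh, kw = filters.shape
    flipped = filters[:, :, ::-1, ::-1]
    out = np.zeros((b, n, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = images[:, :, i:i + kh, j:j + kw]
            out[:, :, i, j] = np.einsum('bckl,nckl->bn', patch, flipped)
    return out


# Same shapes as in the question: 100 images of 76x76, 10 filters of 6x6
rng = np.random.RandomState(0)
xdata_test = rng.uniform(-1, 1, size=(100, 76, 76)).astype('float32')
CONVfilter = rng.uniform(-1, 1, size=(10, 1, 6, 6)).astype('float32')

reference = conv2d_fft_valid(xdata_test.reshape(100, 1, 76, 76), CONVfilter)
print(reference.shape)  # (100, 10, 71, 71)
```

If the Theano function compiles with either setting of fft_flag, its result could be compared against reference with np.allclose and a loose tolerance, since the GPU path works in float32.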

0 answers:

No answers