在"扫描期间存储动态维度数组列表"在theano中的功能

时间:2015-11-07 14:17:35

标签: python theano

我想在theano scan函数中创建一个动态数组列表,如下所示。

import theano
import theano.tensor as T
import numpy as np

x_var = T.dmatrix('x_var')
y_var = T.dmatrix('y_var')

def _op(y, h_tm1):
    h_next = T.concatenate([h_tm1, y.dimshuffle('x', 0)])
    return h_next

h_res, _ = theano.scan(fn=_op,
                       sequences=y_var,
                       outputs_info=x_var)

fn = theano.function(inputs=[x_var, y_var],
                     outputs=h_res)

x = np.asarray([[0, 0, 0]],
               dtype=theano.config.floatX)
y = np.asarray([[1,1,1], [2,2,2]],
               dtype=theano.config.floatX)
res = fn(x, y)
print res

我想要做的是在每次迭代中,将y_var的一个向量追加到h_tm1的末尾,然后返回新矩阵。但是,似乎扫描无法在运行时中存储动态维矩阵列表?

我发现错误发生在h_next = T.concatenate([h_tm1, y.dimshuffle('x', 0)])

theano报告错误如

  

ValueError:无法将形状(2,3)的输入数组广播为形状(1,3)",省略了有关此错误的其他详细信息。

0 个答案:

没有答案