在强制转换后无法更新共享张量变量的子集

时间:2015-06-03 17:54:18

标签: python-2.7 numpy theano

我有以下代码:

import theano.tensor as T

Words = theano.shared(value = U, name = 'Words')
zero_vec_tensor = T.vector()
zero_vec = np.zeros(img_w, dtype = theano.config.floatX)
set_zero = theano.function([zero_vec_tensor], updates=[(Words, T.set_subtensor(Words[0,:], zero_vec_tensor))])

编译正常(其中U是dtype float64的numpy数组。)

为防止将来出现类型错误,我想将共享张量Words投射到float32(或theano.config.floatX,这与我将floatX设置为float32相同在配置文件中)。

我添加Words = T.cast(Words, dtype = theano.config.floatX)然后我收到以下错误:

TypeError: ('update target must be a SharedVariable', Elemwise{Cast{float32}}.0)

我不明白为什么。根据此question,使用set_subtensor应该允许我更新共享变量的子集。

如何在能够更新共享张量的同时转换它?

1 个答案:

答案 0 :(得分:1)

问题是您正在尝试更新符号变量,而不是共享变量。

U = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
Words = theano.shared(value=U, name='Words')
zero_vec_tensor = T.vector()
set_zero = theano.function([zero_vec_tensor], updates=[(Words, T.set_subtensor(Words[0, :], zero_vec_tensor))])

工作正常,因为你要更新的东西,Words是一个共享变量。

U = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
Words = theano.shared(value=U, name='Words')
Words = T.cast(Words, dtype = theano.config.floatX)
zero_vec_tensor = T.vector()
set_zero = theano.function([zero_vec_tensor], updates=[(Words, T.set_subtensor(Words[0, :], zero_vec_tensor))])

不起作用,因为现在Words不再是共享变量,它是一个符号变量,在执行时,会计算将共享变量中的值转换为theano.config.floatX

共享变量的dtype由分配给它的值决定。所以你可能只需要改变U的类型:

U = np.array([[1, 2, 3], [4, 5, 6]], dtype=theano.config.floatX)

或者使用numpy而不是象征性地投射它:

U = np.dtype(theano.config.floatX).type(U)