更新已使用theano.tensor.cast()转换的变量

时间:2015-03-25 12:40:07

标签: python theano

我正在尝试更新函数中的theano变量,简化如下:

copy_func = theano.function(
    inputs=[idx],
    updates=[
        (a_variable, T.set_subtensor(a_variable[some_ptr], another_variable[idx]))
    ]
)

我的问题是我收到了错误

TypeError: ('update target must be a SharedVariable', Elemwise{Cast{int32}}.0)

我获取此变量的方法是使用以下内容(主要从deeplearning.net教程中复制)(another_variable同样初始化):

a_variable = theano.shared(np.asarray(data,
                               dtype=theano.config.floatX),
                 borrow=True)
print type(a_variable)
a_variable = T.cast(a_variable, 'int32')
print type(a_variable)

打印

<class 'theano.tensor.sharedvar.TensorSharedVariable'>
<class 'theano.tensor.var.TensorVariable'>

即,变量不再是#34;共享&#34;,解释错误。 这是有道理的,因为我猜这个变量现在只是原始共享浮点数的一个转换视图。但是,如何更新有效转换的变量?

1 个答案:

答案 0 :(得分:1)

我自己解决了这个问题,答案当然是显而易见的。

我没有使用已转换的版本覆盖a_variable变量,而是保留了未发布的版本:

a_variable_casted = T.cast(a_variable, 'int32')

现在已在a_variable上完成更新,而a_variable_casted则用于执行早先使用的a_variable计算。

显然有一种更优雅的方式可以做到这一点,在这种情况下,我很乐意听到它!