在PyMC3中使用theano.scan会导致TypeError:切片索引必须为整数或None或具有__index__方法

时间:2018-06-30 18:23:23

标签: theano pymc3 theano.scan

我想在pymc3中使用theano.scan。当我添加两个以上的变量作为%s时遇到问题。这是一个简单的示例:

sequences

导致以下错误:

import numpy as np
import pymc3 as pm
import theano
import theano.tensor as T

a = np.ones(5)
b = np.ones(5)

basic_model = pm.Model()
with basic_model:
    a_plus_b, _ = theano.scan(fn=lambda a, b: a + b, sequences=[a, b])

但是,当我在pymc模型块外运行相同的theano.scan时,一切正常:

Traceback (most recent call last):
File "StackOverflowExample.py", line 23, in <module>
sequences=[a, b])
File "\Anaconda3\lib\site-packages\theano\scan_module\scan.py", line 586, in scan
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
File "\Anaconda3\lib\site-packages\theano\scan_module\scan.py", line 586, in <listcomp>
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
TypeError: slice indices must be integers or None or have an __index__ method

像应该的那样打印a = T.vector('a') b = T.vector('b') a_plus_b, update = theano.scan(fn=lambda a, b: a + b, sequences=[a, b]) a_plus_b_function = theano.function(inputs=[a, b], outputs=a_plus_b, updates=update) a = np.ones(5) b = np.ones(5) print(a_plus_b_function(a, b))

此外,该问题似乎与添加多个[2. 2. 2. 2. 2.]有关。当sequences中只有一个变量而sequences中只有一个变量时,一切工作都很好。以下代码有效:

non-sequences

按预期方式打印a = np.ones(5) c = 2 basic_model = pm.Model() with basic_model: a_plus_c, _ = theano.scan(fn=lambda a, c: a + c, sequences=[a], non_sequences=[c]) a_plus_c_print = T.printing.Print('a_plus_c')(a_plus_c)

注意:我不能仅仅使用+ b来代替theano.scan,因为我的实际功能更加复杂。我实际上想拥有这样的东西:

a_plus_c __str__ = [ 3.  3.  3.  3.  3.]

1 个答案:

答案 0 :(得分:1)

原来这是一个简单的错误!只要将ab定义为张量变量,一切就可以正常工作。将这两行相加即可完成工作:

a = T.as_tensor_variable(np.ones(5))
b = T.as_tensor_variable(np.ones(5))