Theano扫描并重复

时间:2016-02-12 00:34:20

标签: python machine-learning theano pymc3

在Theano中,有一个选项可以使用重复函数T.repeat(A,B)并提供一对向量,这样A[i]的每个元素都会重复B[i]次。

不幸的是,这个操作没有定义的渐变(它会抛出一个未实现的异常)这是一个问题,因为我试图将它与Pymc3的基于渐变的采样器一起使用。

我想我可以使用scan函数解决这个问题,并为两个向量的每个元素递归调用repeat,但是我的代码不起作用,可能是因为我错误地调用了scan。任何人都可以帮助我理解为什么以下代码不起作用?

A = T.dvector('A')
B = T.ivector('B')
A.tag.test_value = np.array(np.random.rand(2), dtype = "float32")
B.tag.test_value = np.array(np.random.rand(2), dtype = "int32")
th.config.compute_test_value = 'warn'

results, updates = th.scan(fn = lambda prior_result, A, B: A.repeat(B),
                          sequences = [A, B],
                          outputs_info = T.constant([1,4,4,4]))

b = th.function(inputs=[A,B], outputs=results.flatten())
b([1],[4])

我希望这会返回[1,1,1,1],但它会返回以下错误。

    395     except AttributeError:
    396         return _wrapit(a, 'repeat', repeats, axis)
--> 397     return repeat(repeats, axis)
    398 
    399 

ValueError: operands could not be broadcast together with shape (1,) (4,)

我在Pymc3 github上提出了一个issue来看看这是否应该更加永久地修复,但我认为这是一个很好的机会,无论如何都要为我了解Theano,如果我能解决问题也许我可以回馈这个项目。

1 个答案:

答案 0 :(得分:0)

我在这里看到两件事:

  1. lambda表达式中的排序错误:它应该是A,B,prior_result(现在B被视为outputs_info)
  2. A.repeat(B)的形状不同于prior_result的形状(在编译的这个阶段)
  3. 快速修复:只需从scan的参数中删除outputs_info(以及从lambda中删除prior_result),你就会得到[1,1,1,1]。