如何在Theano的TensorVariable上执行范围?

时间:2015-11-26 01:49:00

标签: python theano

如何在Theano的TensorVariable上执行范围?

示例:

import theano.tensor as T
from theano import function

constant = T.dscalar('constant')
n_iters = T.dscalar('n_iters')
start = T.dscalar('start')
result = start  

for iter in range(n_iters):
    result = start + constant

f = function([start, constant, n_iters], result)
print('f(0,2,5): {0}'.format(f(1,2)))

返回错误:

Traceback (most recent call last):
  File "test_theano.py", line 9, in <module>
    for iter in range(n_iters):
TypeError: range() integer end argument expected, got TensorVariable.

在Theano的TensorVariable上使用范围的正确方法是什么?

1 个答案:

答案 0 :(得分:7)

目前还不清楚这段代码的用途是什么,因为即使循环有效,它也不会计算任何有用的东西:result总是等于startconstant的总和无论迭代次数如何。

我假设打算像这样计算result

for iter in range(n_iters):
    result = result + constant

问题在于你将符号性,延迟执行,Theano操作与非符号,立即执行的Python操作混合在一起。

range是一个Python函数,它需要一个Python整数参数,但是你提供了一个Theano符号值(n_iters,它有一个double类型而不是一个整数类型,但我们假设它实际上是一个iscalar代替dscalar)。就Python而言,所有Theano符号张量都只是对象:Theano库中某个类类型的实例;他们绝对不是整数。即使你眯着眼睛并试图假装Theano iscalar看起来像一个Python整数,它仍然不起作用,因为Python操作立即执行,这意味着n_iters需要立即获得一个值。另一方面,Theano没有任何iscalar的值,除非通过调用已编译的Theano函数(或通过eval)提供一个{。}。

要创建符号范围,您可以使用theano.tensor.arange,其操作方式与NumPy的arange类似,但符号上有效。

示例:

import theano.tensor as T
from theano import function

my_range_max = T.iscalar('my_range_max')
my_range = T.arange(my_range_max)

f = function([my_range_max], my_range)
print('f(10): {0}'.format(f(10)))

输出:

f(10): [0 1 2 3 4 5 6 7 8 9]

通过使n_iters成为符号变量,您隐含地说“我不知道在此循环中需要进行多少次迭代,直到为n_iters稍后提供值”。在这种情况下,您必须使用符号循环而不是Python for循环。在后一种情况下,你必须告诉Python多少次迭代现在,你不能推迟这个决定,直到稍后为n_iters提供一个值。要解决此问题,您需要切换到由Theano的scan运算符提供的符号循环。

以下代码已更改为使用scan(以及其他假设的更改)。

import theano
import theano.tensor as T
from theano import function

constant = T.dscalar('constant')
n_iters = T.iscalar('n_iters')
start = T.dscalar('start')

results, _ = theano.scan(lambda result, constant: result + constant,
                         outputs_info=[start], non_sequences=[constant], n_steps=n_iters)

f = function([start, constant, n_iters], results[-1])
print('f(0,2,5): {0}'.format(f(0, 2, 5)))