如何在Theano中进行元素条件索引比较?

时间:2016-04-07 18:23:25

标签: python-2.7 numpy theano

该操作由两个长度相等的数组Xidx组成,其中idx的值可以在0到(k-1)之间变化,并给出k的值。 / p>

这是用于说明这一点的通用Python代码。

import numpy as np

X = np.arange(6) # Just for a sample of elements
k = 3
idx = numpy.array([[0, 1, 2, 2, 0, 1]]).T # Can only contain values in [0..(k-1)]
np.array([X[np.where(idx==i)[0]] for i in range(k)])

示例输出:

array([[0, 4],
       [1, 5],
       [2, 3]])

请注意,实际上我有理由将idx表示为矩阵而不是矢量。它作为计算的一部分初始化为numpy.zeros((n,1)),其中n的大小为X

我尝试在Theano中实现这一点

import theano
import theano.tensor as T

X = T.vector('X')
idx = T.vector('idx')
k = T.scalar()
c = theano.scan(lambda i: X[T.where(T.eq(idx,i))], sequences=T.arange(k)) 
f = function([X,idx,k],c)

但是我在定义c的行收到了此错误:

TypeError: Wrong number of inputs for Switch.make_node (got 1((<int8>,)), expected 3)

在Theano中有一种简单的方法可以实现吗?

1 个答案:

答案 0 :(得分:1)

使用nonzero()并更正idx的尺寸。

此代码解决了问题

import theano
import theano.tensor as T

X = T.vector('X')
idx = T.vector('idx')
k = T.scalar()
c, updates = theano.scan(lambda i: X[T.eq(idx,i).nonzero()], sequences=T.arange(k)) 
f = function([X,idx,k],c)

对于同一个例子,通过使用Theano:

import numpy as np

X = np.arange(6) 
k = 3
idx = np.array([[0, 1, 2, 2, 0, 1]]).T

f(X, idx.T[0], k).astype(int)

这使输出为

array([[0, 4],
       [1, 5],
       [2, 3]])

如果将idx定义为np.array([0, 1, 2, 2, 0, 1]),则可以使用f(X, idx, k)代替。