theano在第二张量中找到张量元素的指数

时间:2015-08-26 04:41:13

标签: theano

我似乎无法找到解决方案。给定两个theano张量a和b,我想在张量a中找到b中元素的索引。这个例子有帮助,比如a = [1,5,10,17,23,39]和b = [1,10,39],我希望结果是张量a中b值的指数,即[ 0,2,5]。

花了一些时间后,我认为最好的方法是使用扫描;这是我最小的例子。

def getIndices(b_i, b_v, ar):
    pI_subtensor = pI[b_i]
    return T.set_subtensor(pI_subtensor, np.where(ar == b_v)[0])

ar = T.ivector()
b = T.ivector()
pI = T.zeros_like(b)

result, updates = theano.scan(fn=getIndices,
                              outputs_info=None,
                              sequences=[T.arange(b.shape[0], dtype='int32'), b],
                              non_sequences=ar)

get_proposal_indices = theano.function([b, ar], outputs=result)

d = get_proposal_indices( np.asarray([1, 10, 39], dtype=np.int32), np.asarray([1, 5, 10, 17, 23, 39], dtype=np.int32) )

我收到错误:

TypeError: Trying to increment a 0-dimensional subtensor with a 1-dimensional value.
返回语句行中的

。此外,输出需要是单个张量的形状b,我不确定这是否会得到所需的结果。任何建议都会有所帮助。

1 个答案:

答案 0 :(得分:1)

这完全取决于你的阵列有多大。只要它适合内存,您可以按照以下步骤进行操作

import numpy as np
import theano
import theano.tensor as T

aa = T.ivector()
bb = T.ivector()

equality = T.eq(aa, bb[:, np.newaxis])
indices = equality.nonzero()[1]

f = theano.function([aa, bb], indices)

a = np.array([1, 5, 10, 17, 23, 39], dtype=np.int32)
b = np.array([1, 10, 39], dtype=np.int32)

f(a, b)

# outputs [0, 2, 5]