我似乎无法找到解决方案。给定两个theano张量a和b,我想在张量a中找到b中元素的索引。这个例子有帮助,比如a = [1,5,10,17,23,39]和b = [1,10,39],我希望结果是张量a中b值的指数,即[ 0,2,5]。
花了一些时间后,我认为最好的方法是使用扫描;这是我最小的例子。
def getIndices(b_i, b_v, ar):
pI_subtensor = pI[b_i]
return T.set_subtensor(pI_subtensor, np.where(ar == b_v)[0])
ar = T.ivector()
b = T.ivector()
pI = T.zeros_like(b)
result, updates = theano.scan(fn=getIndices,
outputs_info=None,
sequences=[T.arange(b.shape[0], dtype='int32'), b],
non_sequences=ar)
get_proposal_indices = theano.function([b, ar], outputs=result)
d = get_proposal_indices( np.asarray([1, 10, 39], dtype=np.int32), np.asarray([1, 5, 10, 17, 23, 39], dtype=np.int32) )
我收到错误:
TypeError: Trying to increment a 0-dimensional subtensor with a 1-dimensional value.
返回语句行中的。此外,输出需要是单个张量的形状b,我不确定这是否会得到所需的结果。任何建议都会有所帮助。
答案 0 :(得分:1)
这完全取决于你的阵列有多大。只要它适合内存,您可以按照以下步骤进行操作
import numpy as np
import theano
import theano.tensor as T
aa = T.ivector()
bb = T.ivector()
equality = T.eq(aa, bb[:, np.newaxis])
indices = equality.nonzero()[1]
f = theano.function([aa, bb], indices)
a = np.array([1, 5, 10, 17, 23, 39], dtype=np.int32)
b = np.array([1, 10, 39], dtype=np.int32)
f(a, b)
# outputs [0, 2, 5]