我有2个张量a
和b
,它们具有以下形状
>>K.int_shape(a)
(None, 5 , 2)
>>K.int_shape(b)
(None, 5)
我想要的是张量c
>>K.int_shape(c)
(None, 2)
这样,沿着轴0,您可以选择b
中最大元素的索引,并使用它来沿轴1索引a
。
示例 - 说我有
a = np.array([[[2, 7],
[6, 5],
[9, 9],
[4, 2],
[5, 9]],
[[8, 1],
[8, 8],
[3, 9],
[9, 2],
[9, 1]],
[[3, 9],
[6, 4],
[5, 7],
[5, 2],
[5, 6]],
[[7, 5],
[9, 9],
[9, 5],
[9, 8],
[5, 7]],
[[6, 3],
[1, 7],
[3, 6],
[8, 2],
[3, 2]],
[[6, 4],
[5, 9],
[8, 6],
[5, 2],
[5, 2]],
[[2, 6],
[6, 5],
[3, 1],
[6, 2],
[6, 4]]])
我有
b = np.array([[ 0.27, 0.25, 0.23, 0.06, 0.19],
[ 0.3 , 0.13, 0.17, 0.2 , 0.2 ],
[ 0.08, 0.04, 0.40, 0.36, 0.12],
[ 0.3 , 0.33, 0.11, 0.07, 0.19],
[ 0.15, 0.21, 0.30, 0.12, 0.22],
[ 0.3 , 0.13, 0.23, 0.1 , 0.23],
[ 0.26, 0.35 , 0.25 , 0.07, 0.07]])
我期望c
成为
c = np.zeros((7,2))
for i in range(7):
ind = np.argmax(b[i, :])
c[i, :] = a[i, ind, :]
c
array([[ 2., 7.],
[ 8., 1.],
[ 5., 7.],
[ 9., 9.],
[ 3., 6.],
[ 6., 4.],
[ 6., 5.]])
答案 0 :(得分:1)
使用Tensorflow进行后端(我对Theano知之甚少),使用tf.gather_nd()
:
import keras.backend as K
import tensorflow as tf
# `a` and `b` the numpy arrays defined in the question
A = tf.constant(a)
B = tf.constant(b)
# Obtaining your max indices over axis 1, which will be used as indices for axis 1 of A:
col_ind = K.argmax(B, axis=1)
# Creating row range, which will be used as indices for axis 0 of A:
row_ind = K.arange(col_ind.shape[0], dtype='int64')
# Stacking the indices together:
ind = K.stack((row_ind, col_ind), axis=-1)
# Gathering the results:
c = tf.gather_nd(A, ind) # no equivalent I know in K, and no idea about theano...
with tf.Session() as sess:
print(c.eval())
# [[2 7]
# [8 1]
# [5 7]
# [9 9]
# [3 6]
# [6 4]
# [6 5]]
答案 1 :(得分:0)
找到了解决方案
A = K.constant(a)
B = K.constant(b)
mxidx = K.argmax(B, axis=1)
c = K.map_fn(lambda i: A[i, mxidx[i], :], K.arange(A.shape[0], dtype='int64'))
print K.eval(c)
array([[ 2., 7.],
[ 8., 1.],
[ 5., 7.],
[ 9., 9.],
[ 3., 6.],
[ 6., 4.],
[ 6., 5.]], dtype=float32)
编辑:添加运行时信息
%timeit K.eval(c)
The slowest run took 9.76 times longer than the fastest. This could mean
that an intermediate result is being cached.
100000 loops, best of 3: 12.2 µs per loop