我有张量 X ,其形状为(无,56,300,1),另一张张 y ,其形状为< strong>(无,15),这些张量的第一个维度是batch_size,我想用y作为索引得到张量z,z的形状是(None,15,300,1) 即可。有没有可行的方法来做到这一点?
我写了一个简单的代码来测试,因为我发现这对我来说很难,因为在实践中我不知道batch_size(这些张量的第一维是),
这是我的测试代码:
import numpy as np
import tensorflow as tf
# In this test code , batch_size is 4.
# params' shape is (4, 3, 2 ,1), in practice is (None, 56, 300, 1),
params = [
[[['a0'], ['b0']], [['d0'], ['e0']], [['f0'], ['g0']]],
[[['a1'], ['b1']], [['d1'], ['e1']], [['f1'], ['g1']]],
[[['a2'], ['b2']], [['d2'], ['e2']], [['f2'], ['g2']]],
[[['a3'], ['b3']], [['d3'], ['e3']], [['f3'], ['g3']]],
]
# ind's shape is (4, 2) (In practice is (None, 15)),
# so I wanna get output whose's shape is (4, 2, 2, 1), (In practice is (None, 15, 300, 1))
ind = [[1, 0], [0, 2], [2, 0], [2, 1]]
#ouput = [
# [[['d0'], ['e0']], [['a0'], ['b0']]],
# [[['a1'], ['b1']], [['f1'], ['g1']]],
# [[['f2'], ['g2']], [['a2'], ['b2']]],
# [[['f3'], ['g3']], [['d3'], ['e3']]]
#]
with tf.variable_scope('gather') as scope:
tf_par = tf.constant(params)
tf_ind = tf.constant(ind)
res = tf.gather_nd(tf_par, tf_ind)
with tf.Session() as sess:
init = tf.global_variables_initializer()
print sess.run(res)
print res
答案 0 :(得分:3)
使用x
沿第二维切片ind
,即切片
x
的张量(d0, d1, d2,...)
,d0
可能是None
,ind
的索引(d0, n1)
的张量,y
的张量(d0, n1, d2, ...)
,您可以使用tf.gather_nd
和tf.shape
在运行时获取形状:
ind_shape = tf.shape(ind)
ndind = tf.stack([tf.tile(tf.range(ind_shape[0])[:, None], [1, ind_shape[1]]),
ind], axis=-1)
y = tf.gather_nd(x, ndind)
答案 1 :(得分:0)
对于您认为的结果,您应该使用:
ind = [[0, 1], [0, 0], [1, 0], [1, 2], [2, 2], [2, 0], [3, 2], [3, 1]]
<强>更新强>
您可以使用此代码获取所需内容,并使用当前输入:
with tf.variable_scope('gather') as scope:
tf_par = tf.constant(params)
tf_ind = tf.constant(ind)
tf_par_shape = tf.shape(tf_par)
tf_ind_shape = tf.shape(tf_ind)
tf_r = tf.div(tf.range(0, tf_ind_shape[0] * tf_ind_shape[1]), tf_ind_shape[1])
tf_r = tf.expand_dims(tf_r, 1)
tf_ind = tf.expand_dims(tf.reshape(tf_ind, shape = [-1]), 1)
tf_ind = tf.concat([tf_r, tf_ind], axis=1)
res = tf.gather_nd(tf_par, tf_ind)
res = tf.reshape(res, shape = (-1, tf_ind_shape[1], tf_par_shape[2], tf_par_shape[3]))