我的张量T1
为shape (a,b,c)
,即(batch_size, 64, 128)
。我有T2
中的另一个张量shape (d,e,f)
,即(batch_size, 64, 49)
。 T2
在T1
中被视为索引axis=2(f, i.e. 49)
从0到64的索引。对于T1
中axis=1(b, i.e. 64)
的每个点,我想选择由T2
索引的那个点的所有索引。
输入:T1 -> (batch_size, 64, 128)
T2 -> (batch_size, 64, 49)
。索引张量
Desired Output: (batch_size, 64, 49, 128)
我尝试使用batch_gather。但是,它引发了非法内存访问错误。 以下是代码。
T3 = tf.batch_gather(tf.expand_dims(T1,2), T2)
感谢您提供的任何帮助