从张量流张量获取由另一个张量索引的元素。尝试了batch_gather,但是遇到了CUDA_Illegal_Memory

时间:2019-02-14 03:40:09

标签: python tensorflow

我的张量T1shape (a,b,c),即(batch_size, 64, 128)。我有T2中的另一个张量shape (d,e,f),即(batch_size, 64, 49)T2T1中被视为索引axis=2(f, i.e. 49)从0到64的索引。对于T1axis=1(b, i.e. 64)的每个点,我想选择由T2索引的那个点的所有索引。

输入:T1 -> (batch_size, 64, 128)        T2 -> (batch_size, 64, 49)。索引张量

   Desired Output: (batch_size, 64, 49, 128)

我尝试使用batch_gather。但是,它引发了非法内存访问错误。 以下是代码。

T3 = tf.batch_gather(tf.expand_dims(T1,2), T2)

感谢您提供的任何帮助

0 个答案:

没有答案