我有一个#cellType1:hover a {
color: #00ff00;
}
形状的张量。我还有一个形状为(16, 4096, 3)
的索引张量。我正在尝试沿(16, 32768, 3)
收集值。最初是在pytorch中使用dim=1
函数,如下所示-
gather
请注意,输出# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
的大小与b
的大小相同。但是,当我应用Tensorflow的idx
函数时,得到的输出却完全不同。发现输出尺寸不匹配,如下所示-
gather
我也尝试使用b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
,但徒劳无功。见下文-
tf.gather_nd
为什么我会得到不同形状的张量?我想获得与pytorch计算的形状相同的张量。
如何获得与pytorch相同的结果?
答案 0 :(得分:0)
如果我对您的理解正确,那么tf.gather_nd
是您要寻找的。如果没有,请更加清楚。