在tensorflow和pytorch中看到的collect()函数的不同行为

时间:2018-08-31 16:02:18

标签: python tensorflow pytorch

我有一个#cellType1:hover a { color: #00ff00; } 形状的张量。我还有一个形状为(16, 4096, 3)的索引张量。我正在尝试沿(16, 32768, 3)收集值。最初是在pytorch中使用dim=1函数,如下所示-

gather

请注意,输出# a.shape (16L, 4096L, 3L) # idx.shape (16L, 32768L, 3L) b = a.gather(1, idx) # b.shape (16L, 32768L, 3L) 的大小与b的大小相同。但是,当我应用Tensorflow的idx函数时,得到的输出却完全不同。发现输出尺寸不匹配,如下所示-

gather

我也尝试使用b = tf.gather(a, idx, axis=1) # b.shape (16, 16, 32768, 3, 3) ,但徒劳无功。见下文-

tf.gather_nd

为什么我会得到不同形状的张量?我想获得与pytorch计算的形状相同的张量。

如何获得与pytorch相同的结果?

1 个答案:

答案 0 :(得分:0)

如果我对您的理解正确,那么tf.gather_nd是您要寻找的。如果没有,请更加清楚。