I have a tensor of shape (16, 4096, 3). I also have an index tensor of shape (16, 32768, 3). I am trying to gather values along dim=1. Originally this was done in PyTorch with the gather function, as shown below -
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
Note that the output b has the same size as idx. However, when I apply TensorFlow's gather function, I get a completely different output, and the output dimensions do not match, as shown below -
b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
I also tried using tf.gather_nd, but in vain. See below -
b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
Why am I getting tensors of different shapes? I want to obtain a tensor with the same shape as the one PyTorch computes. In other words, I want to know the TensorFlow equivalent of torch.gather.
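(For reference, the rule being asked about is, element-wise, b[i][j][k] = a[i][idx[i][j][k]][k] when gathering along dim=1. A toy NumPy sketch of that rule, with made-up small shapes rather than the question's actual sizes:)

```python
import numpy as np

# torch.gather(a, 1, idx) computes b[i, j, k] = a[i, idx[i, j, k], k]
a = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(float)   # (2, 3, 2)
idx = np.array([[[1, 0], [2, 2]],
                [[0, 1], [1, 0]]])                        # (2, 2, 2), values index dim 1

b = np.empty(idx.shape)
for i in range(idx.shape[0]):
    for j in range(idx.shape[1]):
        for k in range(idx.shape[2]):
            b[i, j, k] = a[i, idx[i, j, k], k]

# NumPy expresses the same thing directly:
same = np.take_along_axis(a, idx, axis=1)
```

Note how b has the shape of idx, not of a, which is exactly the behavior the question expects.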
Answer 0 (score: 1)
For the 2D case, there is a way to do it:
# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)
However, for the N-D case this approach can become very complicated.
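A minimal eager-mode sketch of this 2D trick (TF 2.x and toy values assumed; the idea is to pair each row number with its gathered column index, then look those pairs up with tf.gather_nd):

```python
import tensorflow as tf

# toy 2D case: pick one column index per row, as torch.gather(a, 1, idx) would
a = tf.constant([[10., 11., 12.],
                 [20., 21., 22.]])      # shape (2, 3)
idx = tf.constant([[2], [0]])           # shape (2, 1)

# build (row, column) pairs: [[0, 2], [1, 0]]
row_col = tf.stack([tf.range(tf.shape(idx)[0]), idx[:, 0]], axis=-1)
b = tf.gather_nd(a, row_col)            # a[0, 2] and a[1, 0]
```

Here b ends up as [12., 20.], matching what torch.gather would select row by row.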
Answer 1 (score: 0)
This "should" be a general solution using tf.gather_nd (I have only tested it with rank-2 and rank-3 tensors along the last axis):
def torch_gather(x, indices, gather_axis):
    # if the pytorch gather indices are
    # [[[0, 10, 20], [0, 10, 20], [0, 10, 20]],
    #  [[0, 10, 20], [0, 10, 20], [0, 10, 20]]]
    # the tf gather_nd indices need to be
    # [[0,0,0], [0,0,10], [0,0,20], [0,1,0], [0,1,10], [0,1,20], [0,2,0], [0,2,10], [0,2,20],
    #  [1,0,0], [1,0,10], [1,0,20], [1,1,0], [1,1,10], [1,1,20], [1,2,0], [1,2,10], [1,2,20]]

    # create a tensor containing the indices of every element
    all_indices = tf.where(tf.fill(indices.shape, True))
    # cast to int64 so the dtype matches all_indices (tf.where returns int64)
    gather_locations = tf.reshape(tf.cast(indices, tf.int64),
                                  [indices.shape.num_elements()])

    # splice in our pytorch-style index at the correct axis
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(gather_locations)
        else:
            gather_indices.append(all_indices[:, axis])

    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped
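To see what the index expansion above is doing, here is a self-contained eager-mode sketch (TF 2.x assumed, toy values mine) of the same construction for a rank-2 tensor gathered along the last axis:

```python
import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])
pt_idx = tf.constant([[2, 0],
                      [1, 1]])    # torch.gather(x, 1, pt_idx) picks [[3,1],[5,5]]

# enumerate every output position...
all_indices = tf.where(tf.fill(pt_idx.shape, True))   # [[0,0],[0,1],[1,0],[1,1]]
# ...then splice in the pytorch-style index at the gather axis
gather_locations = tf.reshape(tf.cast(pt_idx, tf.int64), [-1])   # [2, 0, 1, 1]
nd_idx = tf.stack([all_indices[:, 0], gather_locations], axis=-1)

b = tf.reshape(tf.gather_nd(x, nd_idx), pt_idx.shape)
```

The resulting nd_idx is [[0,2],[0,0],[1,1],[1,1]], so tf.gather_nd reads x[0,2], x[0,0], x[1,1], x[1,1], giving [[3., 1.], [5., 5.]].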
Answer 2 (score: 0)
For last-axis gathering, we can use the 2D-reshape trick for the general N-D case, then apply @LiShaoyuan's 2D code above:
# last-axis gathering only - use the 2D-reshape trick for Torch-style nD gathering
def torch_gather(param, id_tensor):
    # 2D torch-gather equivalent from @LiShaoyuan above
    def gather2d(target, id_tensor):
        idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1)
        result = tf.gather_nd(target, idx)
        return tf.expand_dims(result, axis=-1)

    target = tf.reshape(param, (-1, param.shape[-1]))  # reshape to 2D
    target_shape = id_tensor.shape
    id_tensor = tf.reshape(id_tensor, (-1, 1))  # the index is also made 2D
    result = gather2d(target, id_tensor)
    return tf.reshape(result, target_shape)
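A quick sketch of the reshape trick itself, inlined for a small rank-3 example (TF 2.x eager mode and toy values assumed): flattening all but the last axis turns last-axis gathering into the 2D case above.

```python
import tensorflow as tf

param = tf.constant([[[1., 2., 3.], [4., 5., 6.]],
                     [[7., 8., 9.], [10., 11., 12.]]])   # (2, 2, 3)
id_tensor = tf.constant([[[2], [0]],
                         [[1], [2]]])                    # (2, 2, 1)

# flatten everything except the last axis: now it is a plain 2D gather
target = tf.reshape(param, (-1, param.shape[-1]))        # (4, 3)
flat_idx = tf.reshape(id_tensor, (-1, 1))                # (4, 1)

rows = tf.range(tf.shape(flat_idx)[0])
pairs = tf.stack([rows, flat_idx[:, 0]], axis=-1)        # [[0,2],[1,0],[2,1],[3,2]]
result = tf.reshape(tf.gather_nd(target, pairs), id_tensor.shape)
```

This selects param[0,0,2], param[0,1,0], param[1,0,1], param[1,1,2], i.e. [[[3.],[4.]],[[8.],[12.]]] - the same as torch.gather(param, 2, id_tensor).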