在阅读Dynamic Graph CNN for Learning on Point Clouds代码时,我遇到了以下片段:
idx_ = tf.range(batch_size) * num_points
idx_ = tf.reshape(idx_, [batch_size, 1, 1])
point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_) <--- what happens here?
point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
调试该行,确保暗淡无光
point_cloud_flat:(32768,3) nn_idx:(32,1024,20), idx_:(32,1,1)
// indices are (32,1024,20) after broadcasting
阅读tf.gather doc时,我无法理解该功能在输入尺寸之外的尺寸上的作用
答案 0 :(得分:1)
numpy中的等效函数是np.take
,这是一个简单的示例:
import numpy as np
params = np.array([4, 3, 5, 7, 6, 8])
# Scalar indices; (output is rank(params) - 1), i.e. 0 here.
indices = 0
print(params[indices])
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices]) # [4 3 6]
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices]) # [5 7 6]
# Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices]) # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
# [5 7 6]]
在您的情况下,indices
的排名高于params
,因此输出为rank({params
)+ rank(indices
)-1(即2 + 3-1 = 4,即(32,1024,20,3))。 - 1
是因为此时tf.gather(axis=0)
和axis
必须为等级0(因此是标量)。因此,indices
以“花哨”的索引方式获取第一维(axis=0
)的元素。
已编辑:
简而言之,就您而言(如果我没有误解代码的话)
point_cloud
是(32,1024,3),32个批次1024点,其中3
坐标。 nn_idx
是(32,1024,20),是20个邻居的索引
32批次1024点。索引用于在point_cloud
中建立索引。 nn_idx+idx_
(32、1024、20),索引的20个邻居
32批次1024点。索引用于在point_cloud_flat
中建立索引。point_cloud_neighbors
最后是(32,1024,
20,3),与nn_idx+idx_
相同,除了point_cloud_neighbors
是它们的3个坐标,而nn_idx+idx_
只是它们的索引。