tf.gather的索引比输入数据高吗?

时间:2018-12-02 08:04:32

标签: python tensorflow

在阅读Dynamic Graph CNN for Learning on Point Clouds代码时,我遇到了以下片段:

  idx_ = tf.range(batch_size) * num_points
  idx_ = tf.reshape(idx_, [batch_size, 1, 1]) 

  point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
  point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  <--- what happens here?
  point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)

调试该行,确保暗淡无光

point_cloud_flat:(32768,3) nn_idx:(32,1024,20), idx_:(32,1,1) 
// indices are (32,1024,20) after broadcasting

阅读tf.gather doc时,我无法理解该功能在输入尺寸之外的尺寸上的作用

1 个答案:

答案 0 :(得分:1)

numpy中的等效函数是np.take,这是一个简单的示例:

import numpy as np

params = np.array([4, 3, 5, 7, 6, 8])

# Scalar indices; (output is rank(params) - 1), i.e. 0 here.
indices = 0
print(params[indices])

# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices])  # [4 3 6]

# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices])  # [5 7 6]

# Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices])  # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
# [5 7 6]]

在您的情况下,indices的排名高于params,因此输出为rank({params)+ rank(indices)-1(即2 + 3-1 = 4,即(32,1024,20,3))。 - 1是因为此时tf.gather(axis=0)axis必须为等级0(因此是标量)。因此,indices以“花哨”的索引方式获取第一维(axis=0)的元素。

已编辑

简而言之,就您而言(如果我没有误解代码的话)

  • point_cloud是(32,1024,3),32个批次1024点,其中3 坐标。
  • nn_idx是(32,1024,20),是20个邻居的索引 32批次1024点。索引用于在point_cloud中建立索引。
  • nn_idx+idx_(32、1024、20),索引的20个邻居 32批次1024点。索引用于在point_cloud_flat中建立索引。
  • point_cloud_neighbors最后是(32,1024, 20,3),与nn_idx+idx_相同,除了point_cloud_neighbors是它们的3个坐标,而nn_idx+idx_只是它们的索引。