我正在尝试索引张量以从1d张量获得切片或单个元素。我发现使用numpy
索引[:]
和slice vs tf.gather
的方式(差不多30-40%)时,性能存在显着差异。
另外我观察到tf.gather
在标量上使用时会产生很大的开销(在未堆叠的张量上循环)而不是张量。这是一个已知的问题吗?
示例代码(效率低下):
for node_idxs in graph.nodes():
node_indice_list = tf.unstack(node_idxs)
result = []
for nodeid in node_indices_list:
x = tf.gather(..., nodeid)
y = tf.gather(..., nodeid)
result.append(tf.mul(x,y))
return tf.stack(result)
而不是 示例代码(高效):
for node_idxs in graph.nodes():
x = tf.gather(..., node_idxs)
y = tf.gather(..., node_idxs)
return tf.mul(x, y)
据我所知,第一个低效的实现正在做更多的卸载,堆叠然后循环以及更多聚集操作的工作,但是当我运行的节点的顺序是几百个节点时,我没想到100x减速(正在拆卸和聚集在单个标量上的开销很慢,在第一种情况下,我有更多的聚集操作,每个操作单个元素而不是张量的偏移)。是否有更快的索引方式,我尝试了numpy和slice,结果比收集慢。
答案 0 :(得分:0)
首先,代码并没有真正比较 gather 与 Numpy 索引 - 它比较了矢量化索引(tf.gather)与循环索引(Python“for”循环)。循环很慢也就不足为奇了。
请注意,在 Tensorflow 中无论如何都限制了类似 Numpy 的索引 tensor[idxs]
:
仅整数、切片 (:
)、省略号 (...
)、tf.newaxis (None
) 和
标量 tf.int32/tf.int64 张量是有效索引
因此将 tf.gather
用于一般应用。