您如何在TensorFlow中沿粗糙的维度索引RaggedTensor?

时间:2019-03-27 00:57:25

标签: python tensorflow ragged

我需要通过沿参差不齐的维度索引来获取参差不齐的张量中的值。某些索引功能([:, :x][:, -x:][:, x:y])有效,但不能直接索引([:, x]):

R = tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]])
print(R[:, :2]) # RaggedTensor([[1, 2], [4, 5]])
print(R[:, 1:2]) # RaggedTensor([[2], [5]])
print(R[:, 1])  # ValueError: Cannot index into an inner ragged dimension.

documentation解释了失败的原因:

  

RaggedTensors支持多维索引和切片,其中一个   限制:不允许索引到衣衫dimension的维度。这个   这种情况是有问题的,因为指示的值可能存在于某些行中   但其他人没有。在这种情况下,是否应该(1)   引发IndexError; (2)使用默认值;或(3)跳过该值   并返回一个行数少于开始的张量。以下   Python的指导原则(“面对歧义,拒绝   的诱惑”),我们目前不允许此操作。

这很有意义,但是我实际上如何实现选项1、2和3?我是否必须将衣衫agged的数组转换为张量的Python数组,并手动遍历它们?有没有更有效的解决方案?一个可以在TensorFlow图中100%运作而无需通过Python解释器的机器?

1 个答案:

答案 0 :(得分:0)

如果您具有2D RaggedTensor,则可以通过以下方式获得行为(3):

TypeError: arguments did not match any overloaded call:
  QgsVectorLayer.select(QgsRectangle, bool): argument 1 has unexpected type 'QgsFeature'
  QgsVectorLayer.select(int): argument 1 has unexpected type 'QgsFeature'
  QgsVectorLayer.select(unknown-type): argument 1 has unexpected type 'QgsFeature'

您可以通过添加rt.nrows()== tf.size(slice.flat_values)的断言来获得行为(1):

-multidex

要获得行为(2),我认为最简单的方法可能是连接一个默认值的向量,然后再次切片:

def get_column_slice_v3(rt, column):
  assert column >= 0  # Negative column index not supported
  slice = rt[:, column:column+1]
  return slice.flat_values

可以扩展这些以支持更高维度的参差不齐的张量,但是逻辑变得有点复杂。此外,应该有可能将其扩展为支持负列索引。