张量流中的衣衫a的范围

时间:2020-10-20 00:19:36

标签: python tensorflow

我有一个任意嵌套的参差不齐的张量x,需要对其进行屏蔽。像这样:

x = tf.ragged.constant([
    [[12, 9], [5]],
    [[10], [6, 8], [42]],
])

对我来说,最简单的遮罩方法是沿第1轴索引一个元素。有没有办法使行长/分隔相同的衣衫{的arange

x = tf.ragged.constant([
    [[0, 1], [2]],
    [[0], [1, 2], [3]],
])

2 个答案:

答案 0 :(得分:1)

尝试以下代码:

import tensorflow as tf

x = tf.ragged.constant([
    [[12, 9], [5]],
    [[10], [6, 8], [42]],
])

starts = tf.gather(x.nested_row_splits[1], x.nested_row_splits[0])[1:-1]
starts = tf.cast(starts, tf.int32)
len = tf.shape(x.flat_values)[0]
starts = tf.scatter_nd(starts[:,tf.newaxis], starts, [len])
starts = tf.scan(lambda a, x: a + x, starts)
output = tf.range(len) - starts
x = tf.RaggedTensor.from_nested_row_splits(output, x.nested_row_splits)

print(x)

答案 1 :(得分:0)

我设法通过将n-d张量合并到rank2来解决这个问题,计算出一个参差不齐的范围,然后从原来的嵌套行长度上将其基本重塑:

x = tf.ragged.constant([
    [[12, 9], [5]],
    [[10], [6, 8], [42]],
])

x_2d = x.merge_dims(inner_axis=-1, outer_axis=1)
arange_2d = tf.ragged.range(x_2d.row_lengths())
arange_nd = tf.RaggedTensor.from_nested_row_lengths(
    arange_2d.flat_values,
    x.nested_row_lengths(),
)

>>> arange_nd
<tf.RaggedTensor [[[0, 1], [2]], [[0], [1, 2], [3]]]>

请参阅this issue,以获取来自Tensorflow维护人员之一的替代解决方案。