我有一个任意嵌套的参差不齐的张量x
,需要对其进行屏蔽。像这样:
x = tf.ragged.constant([
[[12, 9], [5]],
[[10], [6, 8], [42]],
])
对我来说,最简单的遮罩方法是沿第1轴索引一个元素。有没有办法使行长/分隔相同的衣衫{的arange
:
x = tf.ragged.constant([
[[0, 1], [2]],
[[0], [1, 2], [3]],
])
答案 0 :(得分:1)
尝试以下代码:
import tensorflow as tf
x = tf.ragged.constant([
[[12, 9], [5]],
[[10], [6, 8], [42]],
])
starts = tf.gather(x.nested_row_splits[1], x.nested_row_splits[0])[1:-1]
starts = tf.cast(starts, tf.int32)
len = tf.shape(x.flat_values)[0]
starts = tf.scatter_nd(starts[:,tf.newaxis], starts, [len])
starts = tf.scan(lambda a, x: a + x, starts)
output = tf.range(len) - starts
x = tf.RaggedTensor.from_nested_row_splits(output, x.nested_row_splits)
print(x)
答案 1 :(得分:0)
我设法通过将n-d张量合并到rank2来解决这个问题,计算出一个参差不齐的范围,然后从原来的嵌套行长度上将其基本重塑:
x = tf.ragged.constant([
[[12, 9], [5]],
[[10], [6, 8], [42]],
])
x_2d = x.merge_dims(inner_axis=-1, outer_axis=1)
arange_2d = tf.ragged.range(x_2d.row_lengths())
arange_nd = tf.RaggedTensor.from_nested_row_lengths(
arange_2d.flat_values,
x.nested_row_lengths(),
)
>>> arange_nd
<tf.RaggedTensor [[[0, 1], [2]], [[0], [1, 2], [3]]]>
请参阅this issue,以获取来自Tensorflow维护人员之一的替代解决方案。