Tensorflow参差不齐的张量中的蒙版值

时间:2020-06-04 14:39:01

标签: tensorflow tensorflow2.0

我的张量参差不齐:

tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 15938], [126], [10135], [17665]]], dtype=tf.int32)

我想将长度大于1的行中的元素值设置为特定值。例如:

tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 0], [126], [10135], [17665]]], dtype=tf.int32)

如何在Tensorflow中表达这种转变?

1 个答案:

答案 0 :(得分:1)

参差不齐的张量总是使事情变得棘手,但这是一种可能的实现方式:

import tensorflow as tf

# Using an intermediate NumPy array avoids having the second dimension as ragged
a = tf.ragged.constant([[[17712], [16753], [11850], [13028], [10155], [15734, 15938],
                         [126], [10135], [17665]]], dtype=tf.int32)
# Index from which values are replaced
replace_from_idx = 1
# Replacement value
new_value = 0

# Get size of each element in the last dimension
s = a.row_lengths(axis=-1)
# Make ragged ranges
r = tf.ragged.range(s.flat_values)
# Un-flatten
r = tf.RaggedTensor.from_row_lengths(r, a.row_lengths(1))
# Replace values
m = tf.dtypes.cast(r < replace_from_idx, a.dtype)
out = a * m + new_value * (1 - m)
print(out.to_list())
# [[[17712], [16753], [11850], [13028], [10155], [15734, 0], [126], [10135], [17665]]]