我有一个不规则张量(RaggedTensor):
<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]]>
我的问题是:如何从中删除空的内层列表?期望的结果是:
<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]>
谢谢!
答案 0 :(得分:0)
天哪,这可真不容易:
import tensorflow as tf
@tf.function
def remove_empty_lists(rt):
    """Remove zero-length inner lists from a ragged tensor.

    Drops every entry of the second ragged dimension whose length is zero,
    e.g. ``[[[1, 2]], [[16, 17], []]]`` -> ``[[[1, 2]], [[16, 17]]]``.
    Outer rows themselves are never dropped: an outer row whose inner lists
    were all empty becomes an empty row.

    Unlike a rebuild from ``flat_values``, this keeps any deeper ragged
    levels and the row-split dtype of ``rt`` intact. It also runs in time
    linear in the number of inner lists.

    Args:
        rt: A ``tf.RaggedTensor`` with ``ragged_rank >= 2``.

    Returns:
        A ``tf.RaggedTensor`` with the same ragged rank and row-split dtype
        as ``rt``, with the zero-length inner lists removed and the outer
        row lengths reduced accordingly.

    Raises:
        ValueError: If ``rt.ragged_rank < 2``, i.e. there is no inner list
            level whose emptiness could be tested.
    """
    if rt.ragged_rank < 2:
        raise ValueError(
            f"remove_empty_lists expects ragged_rank >= 2, got {rt.ragged_rank}"
        )
    # rt.values is itself ragged: one row per inner list, concatenated across
    # all outer rows. An inner list is kept iff its length is non-zero.
    keep = rt.values.row_lengths() > 0
    # Re-attach rt's outer row partition so the mask has the [outer, inner]
    # structure of rt (a shape prefix of rt), then filter. boolean_mask
    # recomputes the outer row splits from the surviving entries.
    return tf.ragged.boolean_mask(rt, rt.with_values(keep))
if __name__ == "__main__":
    # Demo inputs: the empty inner lists sit at the end, the start or the
    # middle of an outer row, or appear several times in one row. The author
    # reports that each one prints
    # [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]].
    demo_inputs = [
        [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]],
        [[[1, 2]], [[12, 13], []], [[16, 17]], [[18, 19], [20]]],
        [[[1, 2]], [[], [12, 13]], [[16, 17]], [[18, 19], [20]]],
        [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20], []]],
        [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [], [20]]],
        [[[1, 2]], [[], [12, 13], [], []], [[], [], [], [], [16, 17]], [[18, 19], [20]]],
    ]
    for nested in demo_inputs:
        ragged = tf.ragged.constant(nested)
        tf.print(remove_empty_lists(ragged))
# Each call above prints [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]
# NOTE: I only tested this for inputs where
# - len(rt.nested_row_lengths()) == 2 (i.e. rt.ragged_rank == 2)
# - zero lengths appear only in rt.nested_row_lengths()[1]