TensorFlow:从不规则张量(RaggedTensor)中删除空元素

时间:2020-09-13 20:58:56

标签: tensorflow tensorflow2.0

我有一个不规则张量(RaggedTensor):

<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]]>

我的问题是:如何从中删除空元素?期望的结果是:

<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]>

谢谢!

1 个答案:

答案 0 :(得分:0)

天哪,这可真不容易:

import tensorflow as tf

@tf.function
def remove_empty_lists(rt):
  """Remove the empty lists at ragged axis 2 of ``rt``.

  Every axis-2 list whose length is zero is dropped; all other values are
  kept in order and the number of outer (axis-0) rows is unchanged. For
  example ``[[[1, 2]], [[16, 17], []]]`` becomes ``[[[1, 2]], [[16, 17]]]``.
  An outer row that contained only empty lists becomes an empty row.

  Args:
    rt: a ``tf.RaggedTensor`` of rank >= 3 (the script's examples are rank-3
      tensors ragged along axes 1 and 2).

  Returns:
    A ``tf.RaggedTensor`` of the same rank and dtype as ``rt`` with the empty
    axis-2 lists removed.
  """
  # Length of every axis-2 list, laid out ragged like rt[:, :] (rank 2).
  inner_lengths = rt.row_lengths(axis=2)
  # A ragged mask whose shape is a prefix of rt's shape removes the masked-out
  # axis-2 lists and lets TensorFlow rebuild the row partitions. Unlike
  # hand-computing nested row lengths, this works for int32 or int64 row
  # splits, needs no static shapes, and avoids an (#empties x #rows) matrix.
  return tf.ragged.boolean_mask(rt, inner_lengths > 0)


if __name__ == "__main__":
  # Each input differs only in where (and how many) empty inner lists occur;
  # every printed line should be
  #   [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]
  # NOTE: these cases only exercise rank-3 ragged tensors (two levels of
  # nested row lengths) whose entries to drop are empty innermost lists.
  demo_inputs = (
      [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]],
      [[[1, 2]], [[12, 13], []], [[16, 17]], [[18, 19], [20]]],
      [[[1, 2]], [[], [12, 13]], [[16, 17]], [[18, 19], [20]]],
      [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20], []]],
      [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [], [20]]],
      [[[1, 2]], [[], [12, 13], [], []], [[], [], [], [], [16, 17]], [[18, 19], [20]]],
  )
  for nested in demo_inputs:
    ragged = tf.ragged.constant(nested)
    tf.print(remove_empty_lists(ragged))