How to omit zeros in a 4-D tensor in TensorFlow?

Posted: 2017-02-03 20:29:00

Tags: python tensorflow

Say I have a tensor:

import tensorflow as tf
t = tf.Variable([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
                 [[0., 132., 0., 0., 234., 0., 1., 24., 0.], [43., 0., 0., 124., 0., 0., 0., 52., 645]]]])

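I want to drop the zeros and be left with a tensor of shape (1, 2, 2, 4), where 4 is the number of non-zero elements in each innermost vector, like: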

t = tf.Variable([[[[235., 1006., 23., 42], [77., 12., 33., 55.]],
                 [[132., 234., 1., 24.], [43., 124., 52., 645]]]])

I can do this on a 1-D tensor with a boolean mask. How can I omit the zeros in a 4-D tensor? Can this be generalized to higher ranks?
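The 1-D boolean-mask approach can carry over to any rank, provided every innermost vector keeps the same number of entries: masking returns the kept values as a flat vector in row-major order, and reshaping to the original leading dimensions plus an inferred last dimension regroups them. The sketch below uses NumPy purely for illustration (it is not TensorFlow code), and drop_zeros_last_axis is a hypothetical name; in TensorFlow the corresponding calls would be tf.boolean_mask followed by tf.reshape.

```python
import numpy as np

def drop_zeros_last_axis(x):
    """Hypothetical helper: remove the zeros from every innermost vector of x.

    Assumes every vector along the last axis has the same number of non-zeros.
    """
    x = np.asarray(x)
    mask = x != 0
    # Non-zeros per innermost vector; they must all agree for a dense result.
    counts = np.count_nonzero(mask, axis=-1)
    if np.unique(counts).size > 1:
        raise ValueError("vectors have different numbers of non-zero entries")
    # Boolean masking returns a 1-D array of the kept values in row-major
    # (C) order, so each vector's survivors stay contiguous and in order.
    kept = x[mask]
    # Restore the leading dimensions; -1 lets NumPy infer the new last size.
    return kept.reshape(x.shape[:-1] + (-1,))

t = [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
       [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
      [[0., 132., 0., 0., 234., 0., 1., 24., 0.],
       [43., 0., 0., 124., 0., 0., 0., 52., 645.]]]]
result = drop_zeros_last_axis(t)  # shape (1, 2, 2, 4)
```

The explicit count check is a design choice: without it, a reshape can succeed on inputs where the vectors have different numbers of non-zeros and silently regroup values across vectors.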

1 Answer:

Answer 0 (score: 2)

Using TensorFlow 1.12:

import tensorflow as tf

def batch_of_vectors_nonzero_entries(batch_of_vectors):
  """Extracts the non-zero entries from batched vectors (removes the zeros).

  Requires that each vector have the same number of non-zero entries.

  Args:
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N].
  Returns:
    A Tensor with shape [..., M] where M is the number of non-zero entries in
    each vector.
  """
  nonzero_indices = tf.where(tf.not_equal(
      batch_of_vectors, tf.zeros_like(batch_of_vectors)))
  # gather_nd gives us a vector containing the non-zero entries of the
  # original Tensor
  nonzero_values = tf.gather_nd(batch_of_vectors, nonzero_indices)
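  # Next, reshape so that all but the last dimension is the same as the input
  # Tensor. This is only correct if each vector has the same number of
  # non-zero values: otherwise the reshape either fails or, if the total count
  # happens to divide evenly, silently regroups values across vectors.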
  reshaped_nonzero_values = tf.reshape(
      nonzero_values,
      tf.concat([tf.shape(batch_of_vectors)[:-1], [-1]], axis=0))
  return reshaped_nonzero_values

t = tf.Variable(
    [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
       [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
      [[0., 132., 0., 0., 234., 0., 1., 24., 0.],
       [43., 0., 0., 124., 0., 0., 0., 52., 645]]]])
nonzero_t = batch_of_vectors_nonzero_entries(t)

with tf.Session():
    tf.global_variables_initializer().run()
    result_evaled = nonzero_t.eval()
    print(result_evaled.shape, result_evaled)

Prints:

(1, 2, 2, 4) [[[[  2.35000000e+02   1.00600000e+03   2.30000000e+01   4.20000000e+01]
   [  7.70000000e+01   1.20000000e+01   3.30000000e+01   5.50000000e+01]]

  [[  1.32000000e+02   2.34000000e+02   1.00000000e+00   2.40000000e+01]
   [  4.30000000e+01   1.24000000e+02   5.20000000e+01   6.45000000e+02]]]]
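The final reshape in this pipeline only checks that the total number of non-zeros divides evenly into the leading dimensions; it does not check per-vector counts. A small NumPy analogue makes this visible, with np.argwhere standing in for tf.where and tuple indexing standing in for tf.gather_nd (an illustration under that analogy, not TensorFlow itself; where_gather_reshape is a hypothetical name):

```python
import numpy as np

def where_gather_reshape(x):
    """NumPy analogue (for illustration only) of the answer's
    tf.where -> tf.gather_nd -> tf.reshape pipeline."""
    x = np.asarray(x)
    # Like tf.where(mask): one row of coordinates per non-zero, in row-major order.
    coords = np.argwhere(x != 0)
    # Like tf.gather_nd(x, coords): fetch the value stored at each coordinate.
    values = x[tuple(coords.T)]
    # Like the final tf.reshape: only the total count is checked, not per-vector counts.
    return values.reshape(x.shape[:-1] + (-1,))

# Equal counts (2 and 2): each output row holds one input vector's non-zeros.
ok = where_gather_reshape([[1, 0, 2], [0, 3, 4]])            # [[1, 2], [3, 4]]
# Unequal counts (1 and 3), but the total of 4 divides by 2, so the reshape
# succeeds and the value 2 from the second vector lands in the first row.
mixed = where_gather_reshape([[1, 0, 0, 0], [2, 3, 4, 0]])   # [[1, 2], [3, 4]]
```

When the total does not divide evenly (for example counts of 1 and 2 across two vectors), NumPy's reshape raises a ValueError, the NumPy counterpart of the reshape failing.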

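If the result would be ragged (i.e. the vectors have different numbers of non-zero entries), it may be worth looking into SparseTensor.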
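One way to picture the ragged case: a SparseTensor is described by an indices array, a values array and a dense_shape, so the non-zeros of each vector can be "packed to the front" by giving the k-th surviving entry of a vector the last-axis index k. The NumPy sketch below builds such a triple; left_packed_sparse is a hypothetical helper, and passing its output to tf.SparseTensor(indices, values, dense_shape) is an assumption about how you might use it, not something the answer shows.

```python
import numpy as np

def left_packed_sparse(x):
    """Hypothetical helper: COO (indices, values, dense_shape) of the non-zeros
    of every innermost vector, packed to the front of that vector.

    Vectors may keep different numbers of entries; the last dimension of
    dense_shape is the largest per-vector count.
    """
    x = np.asarray(x)
    indices, values = [], []
    max_count = 0
    # Visit every innermost vector in row-major order, so the resulting
    # indices are also sorted in row-major order.
    for lead in np.ndindex(*x.shape[:-1]):
        vec = x[lead]
        kept = vec[vec != 0]
        max_count = max(max_count, kept.size)
        for k, v in enumerate(kept):
            indices.append(lead + (k,))  # k-th surviving entry of this vector
            values.append(v)
    indices = np.array(indices, dtype=np.int64).reshape(-1, x.ndim)
    values = np.array(values, dtype=x.dtype)
    dense_shape = x.shape[:-1] + (max_count,)
    return indices, values, dense_shape

# First vector keeps 2 entries, second keeps 1: a ragged result.
idx, vals, shape = left_packed_sparse([[1, 0, 2], [0, 0, 3]])
```

Densifying this triple pads the shorter vectors with zeros at the end, which is one common way to turn a ragged result back into a regular tensor.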