Say I have a tensor:
import tensorflow as tf
t = tf.Variable([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
[[0., 132., 0., 0., 234., 0., 1., 24., 0.], [43., 0., 0., 124., 0., 0., 0., 52., 645]]]])
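I want to drop the zeros and be left with a tensor of shape (1, 2, 2, 4), where 4 is the number of non-zero elements in each vector along the last axis, like this: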
t = tf.Variable([[[[235., 1006., 23., 42], [77., 12., 33., 55.]],
[[132., 234., 1., 24.], [43., 124., 52., 645]]]])
I can do this on a 1-D tensor with a boolean mask. How can I drop the zeros in a 4-D tensor? Can this be generalized to higher ranks?
Answer 0 (score: 2)
Using TensorFlow 1.12:
import tensorflow as tf
def batch_of_vectors_nonzero_entries(batch_of_vectors):
  """Removes zero entries from batched vectors.

  Requires that each vector have the same number of non-zero entries.

  Args:
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N].

  Returns:
    A Tensor with shape [..., M] where M is the number of non-zero entries in
    each vector.
  """
  nonzero_indices = tf.where(tf.not_equal(
      batch_of_vectors, tf.zeros_like(batch_of_vectors)))
  # gather_nd gives us a vector containing the non-zero entries of the
  # original Tensor, in row-major order
  nonzero_values = tf.gather_nd(batch_of_vectors, nonzero_indices)
  # Next, reshape so that all but the last dimension is the same as the input
  # Tensor. The result is only correct if each vector has the same number of
  # non-zero values; the reshape raises an error when the total count does
  # not divide evenly across the leading dimensions.
  reshaped_nonzero_values = tf.reshape(
      nonzero_values,
      tf.concat([tf.shape(batch_of_vectors)[:-1], [-1]], axis=0))
  return reshaped_nonzero_values
t = tf.Variable(
[[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
[77., 0., 0., 12., 0., 0., 33., 55., 0.]],
[[0., 132., 0., 0., 234., 0., 1., 24., 0.],
[43., 0., 0., 124., 0., 0., 0., 52., 645]]]])
nonzero_t = batch_of_vectors_nonzero_entries(t)
with tf.Session():
  tf.global_variables_initializer().run()
  result_evaled = nonzero_t.eval()
  print(result_evaled.shape, result_evaled)
This prints:
(1, 2, 2, 4) [[[[ 2.35000000e+02 1.00600000e+03 2.30000000e+01 4.20000000e+01]
[ 7.70000000e+01 1.20000000e+01 3.30000000e+01 5.50000000e+01]]
[[ 1.32000000e+02 2.34000000e+02 1.00000000e+00 2.40000000e+01]
[ 4.30000000e+01 1.24000000e+02 5.20000000e+01 6.45000000e+02]]]]
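The same mask-then-reshape idea can be checked quickly without a TensorFlow session. The following is a minimal NumPy sketch, not part of the original answer: the function name `nonzero_entries` is hypothetical, and the explicit count check is an added assumption about how one might guard against vectors with differing numbers of non-zero entries.

```python
import numpy as np


def nonzero_entries(batch_of_vectors):
    """Keeps the non-zero entries of each vector along the last axis.

    Hypothetical helper: raises ValueError unless every vector has the
    same number of non-zero entries.
    """
    a = np.asarray(batch_of_vectors)
    # Count non-zero entries per vector so that a mismatch is reported
    # instead of being silently reshaped into misaligned rows.
    counts = np.atleast_1d(np.count_nonzero(a, axis=-1))
    if counts.size and counts.min() != counts.max():
        raise ValueError("vectors have differing numbers of non-zero entries")
    # Boolean indexing returns the selected values in row-major order,
    # so the entries of each vector stay contiguous.
    values = a[a != 0]
    # Restore the leading dimensions and infer the last one.
    return values.reshape(a.shape[:-1] + (-1,))


t = np.array(
    [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
       [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
      [[0., 132., 0., 0., 234., 0., 1., 24., 0.],
       [43., 0., 0., 124., 0., 0., 0., 52., 645.]]]])

result = nonzero_entries(t)
print(result.shape)  # (1, 2, 2, 4)
```

Because only the last axis is inferred, the same function should work for any rank, which matches the question about generalizing beyond 4-D.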
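If the result would be ragged (that is, the vectors have different numbers of non-zero entries), it may be worth looking into SparseTensor.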
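To illustrate the ragged case, here is a small NumPy sketch (an illustration only, not TensorFlow code) of the three pieces of information a sparse representation keeps: the indices of the non-zero entries, their values, and the dense shape. The array `ragged` and the variable names are made up for this example.

```python
import numpy as np

# Two vectors with different numbers of non-zero entries (2 and 3),
# so they cannot be packed into a rectangular [2, M] array.
ragged = np.array([[0., 5., 0., 7.],
                   [1., 0., 2., 3.]])

mask = ragged != 0
# One row of coordinates per non-zero entry, in row-major order.
indices = np.argwhere(mask)
# The non-zero values, in the same order as `indices`.
values = ragged[mask]
# The shape of the original dense array.
dense_shape = ragged.shape

# Alternatively, keep a plain Python list of per-vector results.
rows = [row[row != 0].tolist() for row in ragged]

print(indices.tolist())  # [[0, 1], [0, 3], [1, 0], [1, 2], [1, 3]]
print(values.tolist())   # [5.0, 7.0, 1.0, 2.0, 3.0]
print(rows)              # [[5.0, 7.0], [1.0, 2.0, 3.0]]
```

The indices/values pair corresponds to what `tf.where` and `tf.gather_nd` produce in the answer above, before the reshape step that requires equal counts.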