tensorflow如何从批处理输入中将非零值

时间:2019-04-02 07:48:37

标签: python tensorflow keras

第一个索引是None和批处理索引

在以下示例中,批处理大小为2(两行),输入长度为3

In [12]: ar = [[0,1,2],
    ...: [2,0,3]]

In [13]: mask = tf.greater(ar, 0)
    ...: non_zero_array = tf.boolean_mask(ar, mask)

In [14]: non_zero_array.eval(session=sess)
Out[14]: array([1, 2, 2, 3], dtype=int32)

我想要输出 [[1,2], [2,3]]代替[1,2,2,3](形状为[None,input_length])

我正在尝试自行实现mask_zero功能,因为一旦将mask_zero=True赋予嵌入层,就无法将其馈送到密集层(我将其他张量连接起来并压平然后馈送到密集层,Flatten不接受mask_zero

下面我得到item_average,它是prior_ids的平均嵌入,我想在不使用任何获取嵌入的情况下从0除去prior_idsmask_zero=0

 selected = self.item_embedding_layer(prior_ids)
 embedding_sum = tf.reduce_sum(selected, axis=1)
 non_zero_count =  tf.cast(tf.math.count_nonzero(prior_ids, axis=1), tf.float32)
 item_average = embedding_sum / tf.expand_dims(non_zero_count, axis=1)

1 个答案:

答案 0 :(得分:0)

这是一个可能的实现,它可以删除具有任意数量维数的张量的最后维中的零:

import tensorflow as tf

def remove_zeros(a):
    a = tf.convert_to_tensor(a)
    # Mask of selected elements
    mask = tf.not_equal(a, 0)
    # Take the first "row" of mask
    row0 = tf.gather_nd(mask, tf.zeros([tf.rank(mask) - 1], dtype=tf.int32))
    # Count number of non-zeros on last axis
    n = tf.math.count_nonzero(row0)
    # Mask elements
    a_masked = tf.boolean_mask(a, mask)
    # Reshape
    result = tf.reshape(a_masked, tf.concat([tf.shape(a)[:-1], [n]], axis=0))
    return result

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    print(sess.run(remove_zeros([[0, 1, 2],
                                 [2, 0, 3]])))
    # [[1 2]
    #  [2 3]]
    print(sess.run(remove_zeros([[[0, 1, 0], [2, 0, 0], [0, 3, 0]],
                                 [[0, 0, 4], [0, 5, 0], [6, 0, 0]]])))
    # [[[1]
    #   [2]
    #   [3]]
    # 
    #  [[4]
    #   [5]
    #   [6]]]