嵌套的tf.map_fn性能较慢

时间:2019-12-18 15:47:32

标签: python tensorflow

我运行下面的代码,以便从给定的索引矩阵(words_chars_ids形状为(6,200,20))中获取填充矩阵。 结果的形状为(6,200,20,emb_size),其中输出中的每个条目都包含一个为1或0(具有emb_size大小)的张量。

我有两个问题:

  1. 有没有一种更优雅的方法来实现此方法(没有嵌套的map_fn)

  2. 性能似乎很慢-有没有更有效的方法来获得结果?

def get_padding_mask(words_chars_ids, emb_size):

    padding_mask = tf.map_fn(
        lambda x: tf.map_fn(
            lambda y: tf.map_fn(
                lambda z: tf.cond(tf.less(z, 1),
                                  lambda: tf.zeros([emb_size, ], dtype=tf.int32),
                                  lambda: tf.ones([emb_size, ], dtype=tf.int32)
                                  ),
                y),
            x),
        words_chars_ids)
    return padding_mask

1 个答案:

答案 0 :(得分:1)

您可以简单地执行以下操作:

def get_padding_mask(words_chars_ids, emb_size):
    mask = tf.dtypes.cast(words_chars_ids >= 1, tf.int32)
    return tf.tile(tf.expand_dims(mask, -1), [1, 1, 1, emb_size])