我运行下面的代码,以便从给定的索引矩阵(words_chars_ids形状为(6,200,20))中获取填充矩阵。 结果的形状为(6,200,20,emb_size),其中输出中的每个条目都包含一个为1或0(具有emb_size大小)的张量。
我有两个问题:
有没有一种更优雅的方法来实现此方法(没有嵌套的map_fn)
性能似乎很慢-有没有更有效的方法来获得结果?
def get_padding_mask(words_chars_ids, emb_size):
padding_mask = tf.map_fn(
lambda x: tf.map_fn(
lambda y: tf.map_fn(
lambda z: tf.cond(tf.less(z, 1),
lambda: tf.zeros([emb_size, ], dtype=tf.int32),
lambda: tf.ones([emb_size, ], dtype=tf.int32)
),
y),
x),
words_chars_ids)
return padding_mask
答案 0 :(得分:1)
您可以简单地执行以下操作:
def get_padding_mask(words_chars_ids, emb_size):
mask = tf.dtypes.cast(words_chars_ids >= 1, tf.int32)
return tf.tile(tf.expand_dims(mask, -1), [1, 1, 1, emb_size])