这是我的训练任务的时间线,它表明op tf.tile()在cpu上运行,占总时间的1/3。我想优化它以加快训练速度。 timeline from tracing log
with tf.name_scope("key_masking"):
key_masks = tf.sequence_mask(keys_length, tf.shape(keys)[1]) # (N, T_k)
key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k)
key_masks = tf.tile(tf.expand_dims(key_masks, 1),
[1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k)