I think this generic code should always work:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out, labels=labels)  # shape (batch, time)
mask = tf.cast(tf.sequence_mask(seq_lens, maxlen=tf.shape(loss)[1]), loss.dtype)  # bool -> float
reduced_loss = tf.reduce_sum(mask * loss) / tf.cast(tf.reduce_sum(seq_lens), loss.dtype)
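For reference, a minimal self-contained sketch of this masked average (assuming the TF 1.x graph API; batch, time, num_classes and the dummy tensors are made up for illustration). Note that tf.sequence_mask returns booleans, so it must be cast before the multiply, and there is no tf.sum, only tf.reduce_sum:

import tensorflow as tf

batch, time, num_classes = 3, 5, 10
out = tf.random_normal([batch, time, num_classes])  # dummy logits
labels = tf.zeros([batch, time], dtype=tf.int64)    # dummy targets
seq_lens = tf.constant([5, 3, 4], dtype=tf.int64)   # shape (batch,)

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out, labels=labels)  # (batch, time)
mask = tf.cast(tf.sequence_mask(seq_lens, maxlen=tf.shape(loss)[1]), loss.dtype)  # zero out padded frames
reduced_loss = tf.reduce_sum(mask * loss) / tf.cast(tf.reduce_sum(seq_lens), loss.dtype)

with tf.Session() as sess:
    print(sess.run(reduced_loss))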
Is there a more efficient way?

This is what I am doing now:
def flatten_with_seq_len_mask(x, seq_lens):
    """
    :param tf.Tensor x: shape (batch,time,...s...)
    :param tf.Tensor seq_lens: shape (batch,) of int64
    :return: tensor of shape (min(batch*time, sum(seq_lens)), ...s...)
    :rtype: tf.Tensor
    """
    with tf.name_scope("flatten_with_seq_len_mask"):
        x = check_dim_equal(x, 0, seq_lens, 0)  # custom helper: asserts x and seq_lens agree on the batch dim
        mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(x)[1])  # shape (batch,time), dtype bool
        return tf.boolean_mask(x, mask)  # drops padded frames, merging (batch,time) into one axis
out_flat = flatten_with_seq_len_mask(out, seq_lens)        # shape (sum(seq_lens), num_classes)
labels_flat = flatten_with_seq_len_mask(labels, seq_lens)  # shape (sum(seq_lens),)
loss_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out_flat, labels=labels_flat)
reduced_loss = tf.reduce_mean(loss_flat)
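Both reductions compute the same number, since tf.reduce_mean over the boolean-masked loss divides the sum by the count of kept frames, which is exactly sum(seq_lens) (assuming every seq_len is at most the time dimension). A quick sanity check, reusing loss, mask and seq_lens from the dummy setup in the first snippet:

flat_loss = tf.boolean_mask(loss, tf.sequence_mask(seq_lens, maxlen=tf.shape(loss)[1]))
masked_mean = tf.reduce_sum(mask * loss) / tf.cast(tf.reduce_sum(seq_lens), loss.dtype)

with tf.Session() as sess:
    a, b = sess.run([masked_mean, tf.reduce_mean(flat_loss)])
    assert abs(a - b) < 1e-5  # both ways of averaging agree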