Efficient reduce_mean with sequence length information

Date: 2016-12-12 20:54:16

Tags: tensorflow

I think this generic code should always work:

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out, labels=labels)
mask = tf.cast(tf.sequence_mask(seq_lens), loss.dtype)  # bool mask must be cast before multiplying
reduced_loss = tf.reduce_sum(mask * loss) / tf.cast(tf.reduce_sum(seq_lens), loss.dtype)
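For reference, the masked-sum reduction above can be checked with a small NumPy sketch (NumPy stands in for the TensorFlow ops here; `masked_mean` is an illustrative name, not part of either API):

```python
import numpy as np

def masked_mean(loss, seq_lens):
    """Mean of per-step losses, counting only valid time steps.

    loss: (batch, time) per-step losses; seq_lens: (batch,) valid lengths.
    Mirrors tf.reduce_sum(mask * loss) / tf.reduce_sum(seq_lens).
    """
    batch, time = loss.shape
    # mask[b, t] is True for t < seq_lens[b], like tf.sequence_mask
    mask = np.arange(time)[None, :] < seq_lens[:, None]
    return (loss * mask).sum() / seq_lens.sum()

loss = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
seq_lens = np.array([3, 1])
print(masked_mean(loss, seq_lens))  # (1+2+3+4)/4 = 2.5
```

Note that the padded positions still pass through the softmax cross-entropy; they are only zeroed out afterwards, which is the inefficiency the question is about.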

Is there a more efficient way?

This is what I currently do:

def flatten_with_seq_len_mask(x, seq_lens):
  """
  :param tf.Tensor x: shape (batch,time,...s...)
  :param tf.Tensor seq_lens: shape (batch,) of int64
  :return: tensor of shape (min(batch*time, sum(seq_len)), ...s...)
  :rtype: tf.Tensor
  """
  with tf.name_scope("flatten_with_seq_len_mask"):
    x = check_dim_equal(x, 0, seq_lens, 0)  # assert that the batch dims match
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(x)[1])  # shape (batch,time)
    return tf.boolean_mask(x, mask)

out_flat = flatten_with_seq_len_mask(out, seq_lens)
labels_flat = flatten_with_seq_len_mask(labels, seq_lens)
loss_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out_flat, labels=labels_flat)
reduced_loss = tf.reduce_mean(loss_flat)
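The flatten-then-mean approach can likewise be sketched in NumPy (boolean indexing with a 2D mask plays the role of tf.boolean_mask; `flatten_with_seq_len_mask_np` is an illustrative name):

```python
import numpy as np

def flatten_with_seq_len_mask_np(x, seq_lens):
    """NumPy analogue of flatten_with_seq_len_mask above:
    keep only the entries x[b, t] with t < seq_lens[b]."""
    mask = np.arange(x.shape[1])[None, :] < seq_lens[:, None]  # (batch, time)
    return x[mask]  # shape (sum(seq_lens), ...)

loss = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
seq_lens = np.array([3, 1])
flat = flatten_with_seq_len_mask_np(loss, seq_lens)
print(flat)         # [1. 2. 3. 4.]
print(flat.mean())  # 2.5 — same result as the masked-sum version
```

The gain is that the cross-entropy is only computed on the sum(seq_lens) valid frames instead of on all batch*time padded frames, and the final reduction is a plain reduce_mean.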

0 Answers:

There are no answers yet.