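I am trying to compute the loss in my RNN model with tf.contrib.seq2seq.sequence_loss. According to the API documentation, this function requires at least three arguments: logits, targets and weights: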
sequence_loss(
    logits,
    targets,
    weights,
    average_across_timesteps=True,
    average_across_batch=True,
    softmax_loss_function=None,
    name=None
)
logits: A Tensor of shape [batch_size, sequence_length, num_decoder_symbols] and dtype float. The logits correspond to the prediction across all classes at each timestep.
targets: A Tensor of shape [batch_size, sequence_length] and dtype int. The target represents the true class at each timestep.
weights: A Tensor of shape [batch_size, sequence_length] and dtype float. weights constitutes the weighting of each prediction in the sequence. When using weights as masking, set all valid timesteps to 1 and all padded timesteps to 0, e.g. a mask returned by tf.sequence_mask.
average_across_timesteps: If set, sum the cost across the sequence dimension and divide the cost by the total label weight across timesteps.
average_across_batch: If set, sum the cost across the batch dimension and divide the returned cost by the batch size.
softmax_loss_function: Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None). Note that to avoid confusion, it is required for the function to accept named arguments.
name: Optional name for this operation, defaults to "sequence_loss".
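To make the weighting and averaging described above concrete, here is a minimal plain-Python sketch of what sequence_loss computes in its default configuration (both average_* flags True, softmax_loss_function=None). It assumes the default per-step loss is softmax cross-entropy on integer class ids, and the helper names sparse_xent and sequence_loss_sketch are hypothetical; this is an illustration, not the library's code.

```python
import math

def sparse_xent(logits, target):
    """Cross-entropy of one timestep: -log(softmax(logits)[target])."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sequence_loss_sketch(logits, targets, weights):
    """Hypothetical analogue of the default sequence_loss behaviour.

    logits:  [batch_size][sequence_length][num_classes] floats
    targets: [batch_size][sequence_length] integer class ids
    weights: [batch_size][sequence_length] floats (1.0 = valid, 0.0 = padding)
    """
    total_loss = 0.0
    total_weight = 0.0
    for seq_logits, seq_targets, seq_weights in zip(logits, targets, weights):
        for step_logits, target, w in zip(seq_logits, seq_targets, seq_weights):
            total_loss += w * sparse_xent(step_logits, target)
            total_weight += w
    # Weighted sum divided by total weight; the small epsilon (an assumption
    # of this sketch) avoids dividing by zero if every weight is 0.
    return total_loss / (total_weight + 1e-12)

logits = [[[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]],
          [[0.1, 0.1, 3.0], [1.0, 1.0, 1.0]]]
targets = [[0, 1], [2, 0]]
weights = [[1.0, 1.0], [1.0, 0.0]]  # last step of the second sequence is padding
print(sequence_loss_sketch(logits, targets, weights))
```

Note that targets here is 2-D while logits is 3-D: each target is a single class id per timestep, and a weight of 0.0 removes a padded step from both the sum and the denominator.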
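My understanding is that logits are my predictions after computing Xw + b, so their shape should be [batch_size, sequence_length, output_size]. targets should then be my labels, but the required shape is [batch_size, sequence_length]. I thought my labels should have the same shape as logits.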
So how do I convert my 3D labels to 2D? Thanks in advance.
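If the 3D labels are one-hot vectors, converting them to the 2D class-id form amounts to taking the argmax along the last axis; in TensorFlow the corresponding op is tf.argmax(labels, axis=-1). The plain-Python sketch below (the function name one_hot_to_indices is hypothetical) just shows the shape change, assuming each label vector is exactly one-hot.

```python
def one_hot_to_indices(labels_3d):
    """Turn [batch_size][sequence_length][num_classes] one-hot labels into
    [batch_size][sequence_length] class ids via the argmax of each vector."""
    return [
        [max(range(len(vec)), key=vec.__getitem__) for vec in seq]
        for seq in labels_3d
    ]

labels_3d = [
    [[0, 0, 1], [1, 0, 0]],
    [[0, 1, 0], [0, 0, 1]],
]
targets_2d = one_hot_to_indices(labels_3d)
print(targets_2d)  # [[2, 0], [1, 2]]
```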
Answer 0 (score: 2)
targets (the labels) do not need to have the same shape as logits.
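If we ignore batch_size for the moment (it is irrelevant to your question), this API simply computes the loss between two sequences as a weighted sum of the per-word losses. Suppose vocab_size is 5 and the target word is 3, and logits gives the prediction for that step as the vector [0.2, 0.1, 0.15, 0.4, 0.15].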
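To compute the loss between the target and the prediction, the target does not need to be given in the prediction's shape as the one-hot vector [0, 0, 0, 1, 0]; TensorFlow does this conversion internally.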
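As a worked check of that example, treat the prediction vector as already-normalized probabilities p (its entries sum to 1) and use 0-based indices. Writing the target as the class id 3 or as the one-hot vector y gives the same cross-entropy:

```latex
\ell_{\text{sparse}} = -\log p_{3} = -\log 0.4 \approx 0.916
\qquad
\ell_{\text{one-hot}} = -\sum_{k=0}^{4} y_k \log p_k,\quad y = [0,0,0,1,0]
\;\Rightarrow\; \ell_{\text{one-hot}} = -\log p_{3}
```

Only the term for the true class survives the sum, which is why storing the single index is enough.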
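You can compare the difference between two APIs: tf.nn.sparse_softmax_cross_entropy_with_logits (labels are class ids) and tf.nn.softmax_cross_entropy_with_logits (labels are full distributions such as one-hot vectors).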
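The difference between the two label formats can be sketched in plain Python. The helpers below (log_softmax, sparse_xent, dense_xent) are hypothetical stand-ins that only mirror the idea of the two TensorFlow ops, not their implementations; the five numbers from the example are reused as raw logits purely to have concrete input.

```python
import math

def log_softmax(logits):
    # Stable log-softmax: subtract the max logit before exponentiating.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def sparse_xent(logits, label_index):
    # Label given as a class id (the "sparse" form).
    return -log_softmax(logits)[label_index]

def dense_xent(logits, label_dist):
    # Label given as a full distribution, e.g. a one-hot vector.
    return -sum(y * lp for y, lp in zip(label_dist, log_softmax(logits)))

logits = [0.2, 0.1, 0.15, 0.4, 0.15]
print(sparse_xent(logits, 3))
print(dense_xent(logits, [0, 0, 0, 1, 0]))  # same value as the sparse form
```

For a one-hot label both forms agree; the sparse form simply skips building the one-hot vector, which is what lets sequence_loss accept 2-D targets.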
Answer 1 (score: 1)
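Your labels should be a 2D matrix of shape [batch_size, sequence_length], and your logits should be a 3D tensor of shape [batch_size, sequence_length, output_size]. So if your labels variable is already of shape [batch_size, sequence_length], there is no need to expand the labels' dimensions.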
If you really do want to expand the dimensions, you can use expanded_variable = tf.expand_dims(the_variable_you_want_to_expand, axis=-1)
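To see what axis=-1 does to the shape, here is a plain-Python analogue on nested lists. The names expand_last_axis and shape are hypothetical helpers for illustration; the sketch assumes a rectangular 2-D input of shape [batch_size, sequence_length].

```python
def expand_last_axis(matrix):
    """Analogue of tf.expand_dims(x, axis=-1) for a 2-D nested list:
    [batch_size][sequence_length] -> [batch_size][sequence_length][1]."""
    return [[[value] for value in row] for row in matrix]

def shape(nested):
    """Read the shape of a rectangular nested list by following first elements."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return dims

targets = [[2, 0], [1, 2]]
print(shape(targets))                    # [2, 2]
print(shape(expand_last_axis(targets)))  # [2, 2, 1]
```

The result gains a trailing axis of size 1, not a one-hot axis of size output_size, so it still would not match the logits shape.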