我想在莎士比亚数据集上尝试CTC损失函数,并且在计算损失时,预测的张量形状为(64,100,65),它与(64,100)的标签形状不匹配。所以我使用了一些数学运算来转换尺寸,但有错误。
代码
def loss(labels, logits):
return tf.keras.losses.categorical_crossentropy(labels, logits)
example_batch_loss = loss(labels=target_example_batch, logits=tf.math.argmax(tf.convert_to_tensor(example_batch_predictions), axis=-1, output_type=tf.int64))
错误
无法计算Mul作为输入#1(从零开始),应该是int64张量,但是是双张量[Op:Mul]
请帮助我找到使用CTC损失的解决方案。
答案 0 :(得分:0)
您要输入模型输出的argmax,即输出值最高的索引。 CTC损失(与大多数损失函数一样)适用于logits,即模型产生的非标准化概率分布。因此,预测形状(64,100,65)而仅预测目标(64,100)并没有错。
但是请注意,当您的模型输出比目标长得多时,CTC可以处理各种情况。典型的用例是语音识别,其中您有大量的信号窗口与相对较少的音素相匹配。如果您的输出长度和目标长度相同,则CTC会退化为标准的交叉熵。
假设example_batch_predictions
是模型输出,然后通过softmax将其标准化,那么您应该这样做:
example_batch_loss = loss(labels=target_example_batch, logits=example_batch_predictions, axis=-1, output_type=tf.int64))