运行CTC损失功能

时间:2020-08-18 07:04:38

标签: python tensorflow machine-learning keras neural-network

我想在莎士比亚数据集上尝试CTC损失函数,并且在计算损失时,预测的张量形状为(64,100,65),它与(64,100)的标签形状不匹配。所以我使用了一些数学运算来转换尺寸,但有错误。

代码

def loss(labels, logits):
  return tf.keras.losses.categorical_crossentropy(labels, logits)

example_batch_loss  = loss(labels=target_example_batch, logits=tf.math.argmax(tf.convert_to_tensor(example_batch_predictions), axis=-1, output_type=tf.int64))

错误

无法计算Mul作为输入#1(从零开始),应该是int64张量,但是是双张量[Op:Mul]

请帮助我找到使用CTC损失的解决方案。

1 个答案:

答案 0 :(得分:0)

您要输入模型输出的argmax,即输出值最高的索引。 CTC损失(与大多数损失函数一样)适用于logits,即模型产生的非标准化概率分布。因此,预测形状(64,100,65)而仅预测目标(64,100)并没有错。

但是请注意,当您的模型输出比目标长得多时,CTC可以处理各种情况。典型的用例是语音识别,其中您有大量的信号窗口与相对较少的音素相匹配。如果您的输出长度和目标长度相同,则CTC会退化为标准的交叉熵。

假设example_batch_predictions是模型输出,然​​后通过softmax将其标准化,那么您应该这样做:

example_batch_loss  = loss(labels=target_example_batch, logits=example_batch_predictions, axis=-1, output_type=tf.int64))