如何使用tf.nn.ctc_loss计算所有空白序列的CTC丢失?

时间:2017-10-09 18:22:38

标签: tensorflow

用手计算所有空白的序列的CTC损失是直截了当的。但是,我找不到使用tf.nn.ctc_loss API执行此操作的方法。我是否会遗漏某些内容,或tf.nn.ctc_loss实施缺少此功能?当批处理中的一些序列没有输出符号时,此功能是必需的。

I reported this on github,它已关闭但没有回答。

环境:tf版本1.3,CPU版本; python 3.5 / 3.6; Win10 / Ubuntu 16.04。

首先,我们从代码开始:

import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 2
labels = tf.SparseTensor(indices=[[0,0]], values=[0], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

tf.nn.ctc_loss按预期行事,并打印正确答案:1.09861231

问题一:

如何计算所有空白序列的ctc损失? tf.nn.ctc_loss API要求值< num_labels,所以我们没有办法实现它?如果我确实将上面示例中的值更改为num_classes - 1(保留的空白ID),则tf.nn.ctc_loss没有抱怨,并返回错误的答案:0.81093025!正确答案是2 * log(3)。重现问题一的代码如下:

import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 2
labels = tf.SparseTensor(indices=[[0,0]], values=[2], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

问题二:

让我们将序列长度更改为1,如下所示

import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 1
labels = tf.SparseTensor(indices=[[0,0]], values=[2], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

再次运行代码。这段代码在Ubuntu中给出了正确答案log(3),但在Win10中崩溃并显示消息:Kernel死了,重新启动。

1 个答案:

答案 0 :(得分:0)

我不了解TF v1.3,但在TF v2.0中,可以通过使用密集张量而不是稀疏来计算具有所有空白的序列的CTC损失。在TF v1.x中,我相信ctc_loss_v2具有相同的行为。

y_true = np.array([[0, 0, 0, 0]])
y_pred = np.zeros((1, 2, 3), np.float32)
input_length = np.array([2])
label_length = np.array([0])
cost = tf.nn.ctc_loss(y_true, y_pred, label_length, input_length)