使用Tensorflow的内部方法" tf.nn.ctc_loss(标签,输入,sequence_length,preprocess_collapse_repeated = False,ctc_merge_repeated = True)"然而,为了计算损失,我发生了错误。
Caused by op 'CTCLoss', defined at:
File "/home/liu/PythonCode/single_deepspeech/util/data_process.py", line 175, in <module>
total_loss = tf.nn.ctc_loss(labels=result, inputs=logits, sequence_length=source_lengths)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/ctc_ops.py", line 145, in ctc_loss
ctc_merge_repeated=ctc_merge_repeated)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_ctc_ops.py", line 164, in _ctc_loss
name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Saw a non-null label (index >= num_classes - 1) following a null label, batch: 0 num_classes: 29 labels:
[[Node: CTCLoss = CTCLoss[ctc_merge_repeated=true, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Reshape_7, ToInt64, Gather, CTCLoss/sequence_length)]
标签是SparseTensor,
indices.shape = [327,2]
values.shape = [327]
dense_shape.shape = [3130]
输入是RNN的输出:logit
logit.shape = [447, 3, 29]
sequence_length是RNN的输入sequence_len,
sequence_length.shape=[408,432,494]
这个问题困扰了我很多天,我们将不胜感激。
答案 0 :(得分:0)
在CTC模型中,有一个值为“num_classes”的空白标签,如果该值或更高的值包含在标签中,您将收到此错误。检查您的标签,并记住单词之间的空格不是您的空白标签,而只是另一个类(普通标签)。