在使用" tf.nn.ctc_loss()"计算损失时,使用Tensorflow设置RNN模型。

时间:2017-08-09 14:03:48

标签: tensorflow

使用Tensorflow的内部方法" tf.nn.ctc_loss(标签,输入,sequence_length,preprocess_collapse_repeated = False,ctc_merge_repeated = True)"然而,为了计算损失,我发生了错误。

Caused by op 'CTCLoss', defined at:
  File "/home/liu/PythonCode/single_deepspeech/util/data_process.py", line 175, in <module>
    total_loss = tf.nn.ctc_loss(labels=result, inputs=logits, sequence_length=source_lengths)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/ctc_ops.py", line 145, in ctc_loss
    ctc_merge_repeated=ctc_merge_repeated)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_ctc_ops.py", line 164, in _ctc_loss
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Saw a non-null label (index >= num_classes - 1) following a null label, batch: 0 num_classes: 29 labels: 
     [[Node: CTCLoss = CTCLoss[ctc_merge_repeated=true, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Reshape_7, ToInt64, Gather, CTCLoss/sequence_length)]

标签是SparseTensor,

  1. indices.shape = [327,2]

  2. values.shape = [327]

  3. dense_shape.shape = [3130]

  4. 输入是RNN的输出:logit

    logit.shape = [447, 3, 29]
    

    sequence_length是RNN的输入sequence_len,

     sequence_length.shape=[408,432,494]
    

    这个问题困扰了我很多天,我们将不胜感激。

1 个答案:

答案 0 :(得分:0)

在CTC模型中,有一个值为“num_classes”的空白标签,如果该值或更高的值包含在标签中,您将收到此错误。检查您的标签,并记住单词之间的空格不是您的空白标签,而只是另一个类(普通标签)。