我想计算OCR问题的CTC损失,但是每当我运行代码时,它都会导致:
NotFoundError:找不到节点的有效设备。 节点:{{node OneHot}} 所有为op OneHot注册的内核:
我正在google colab上使用tensorflow 2。
下面是代码的关键部分:
total_count = records.limit(5).except(:limit).count
我希望得到计算得出的损失,但我得到了以下异常:
def calculate_ctc_loss(predictions, labels, label_length, logit_length):
# shape of predictions (batch_size, max_label_seq_length) --> (64, 20)
# shape of labels (batch_size, timeframes, dictionary_size) --> (64, 20, 30)
label_length_tensor = tf.constant(label_length, shape=(labels.shape[0], 1))
logit_length_tensor = tf.constant(logit_length, shape=(labels.shape[0], 1))
# label_length is a scalar, here 20. the same for logit_length
logits = tf.transpose(predictions, (1, 0, 2))
loss = tf.nn.ctc_loss(labels, logits, logit_length_tensor, label_length_tensor)
return loss
答案 0 :(得分:0)
将标签类型更改为tf.int64/tf.int32