如何解决Tensorflow 2中CTC丢失的“找不到有效设备”问题?

时间:2019-10-30 15:08:32

标签: python tensorflow

我想计算OCR问题的CTC损失,但是每当我运行代码时,它都会导致:

NotFoundError:找不到节点的有效设备。 节点:{{node OneHot}} 所有为op OneHot注册的内核:

我正在google colab上使用tensorflow 2。

下面是代码的关键部分:

total_count = records.limit(5).except(:limit).count

我希望得到计算得出的损失,但我得到了以下异常:

def calculate_ctc_loss(predictions, labels, label_length, logit_length):
    # shape of predictions (batch_size, max_label_seq_length) --> (64, 20)
    # shape of labels (batch_size, timeframes, dictionary_size) --> (64, 20, 30)
    label_length_tensor = tf.constant(label_length, shape=(labels.shape[0], 1)) 
    logit_length_tensor = tf.constant(logit_length, shape=(labels.shape[0], 1))
    # label_length is a scalar, here 20. the same for logit_length
    logits = tf.transpose(predictions, (1, 0, 2))
    loss = tf.nn.ctc_loss(labels, logits, logit_length_tensor, label_length_tensor)
    return loss

1 个答案:

答案 0 :(得分:0)

将标签类型更改为tf.int64/tf.int32