有一个使用Keras的OCR实现,效果很好。
我现在正尝试在tensorflow中重新实现它,请参见代码Here。它们具有相同的模型定义:
在训练时,使用ctc_loss。他们都使用adam优化器来最大程度地减少损失(尽管tensorflow和keras的adam实现方式有所不同)
结果是,keras版本的ctc_loss会减少,而tf版本不会。
Keras损失和优化器定义(在keras_train.py中):
def ctc_lambda_func(args):
base_output, labels, label_length = args
base_output_shape = tf.shape(base_output)
sequence_length = tf.fill([base_output_shape[0],], base_output_shape[1])
return K.ctc_batch_cost(labels, base_output, sequence_length, label_length)
loss = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([base_output, labels, label_length])
train_model = Model(inputs=[base_input, labels, label_length], outputs=loss)
def loss_func(y_true, y_pred):
return y_pred
train_model.compile(loss={'ctc': loss_func}, optimizer='adam', metrics=['accuracy'])
tensorflow损失和优化器定义(在tf_train.py中):
self.loss = tf.nn.ctc_loss(labels=self.labels,
inputs=self.logits,
sequence_length=self.seq_len,
time_major=False)
self.cost = tf.reduce_mean(self.loss)
self.optimizer = tf.train.AdamOptimizer(learning_rate=Config.LEARNING_RATE)
self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
我也将adam更改为不同的优化器,结果是keras版本总是减少,而tenserflow版本则从来没有。
有人可以帮我吗?非常感谢。
答案 0 :(得分:0)
我有相同的问题,并发现ctc_batch_cost是线索。 K.ctc_batch_cost只是包装器,用于使用tf.ctc_loss并按如下所示转换softmax y_pred 在代码中 添加此选项后,此问题已解决。
y_pred = math_ops.log(array_ops.transpose(y_pred,perm = [1,0,2])+ epsilon())
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
Arguments:
y_true: tensor `(samples, max_string_length)`
containing the truth labels.
y_pred: tensor `(samples, time_steps, num_categories)`
containing the prediction, or output of the softmax.
input_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_pred`.
label_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_true`.
Returns:
Tensor with shape (samples,1) containing the
CTC loss of each element.
"""
label_length = math_ops.cast(
array_ops.squeeze(label_length, axis=-1), dtypes_module.int32)
input_length = math_ops.cast(
array_ops.squeeze(input_length, axis=-1), dtypes_module.int32)
sparse_labels = math_ops.cast(
ctc_label_dense_to_sparse(y_true, label_length), dtypes_module.int32)
y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
return array_ops.expand_dims(
ctc.ctc_loss(
inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)