I need to implement a simple OCR model using tf.keras.*. However, Keras' built-in CTC loss (ctc_batch_cost) assumes the blank class is the last one, (num_classes - 1), which does not suit my setup. I want to use tf.nn.ctc_loss instead, which has a nice argument for exactly this: blank_index.
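For context, a bare call to tf.nn.ctc_loss looks roughly like this (a minimal sketch; the shapes and the blank class of 0 are illustrative, not taken from my actual model):

import tensorflow as tf

batch_size, max_time, num_classes = 2, 50, 80
logits = tf.random.normal((max_time, batch_size, num_classes))  # time-major (T, B, C)
labels = tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int32)    # dense, zero-padded

loss = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=tf.constant([3, 2], dtype=tf.int32),
    logit_length=tf.fill([batch_size], max_time),
    blank_index=0,  # blank may be any class, not just num_classes - 1
)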
So I made a simple wrapper that computes the CTC loss:
import tensorflow as tf
from tensorflow.keras import backend as K

class CTCLossWrapper(tf.keras.losses.Loss):
    def __init__(self, blank_class: int,
                 reduction: str = tf.keras.losses.Reduction.AUTO,
                 name: str = 'ctc_loss'):
        super().__init__(reduction=reduction, name=name)
        self.blank_class = blank_class

    def call(self, y_true, y_pred):
        output = y_true['output']
        targets, target_lengths = output['targets'], output['target_lengths']
        # (B, T, C) probabilities -> time-major (T, B, C) log-probabilities
        y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
        # After the transpose, axis 0 is time and axis 1 is batch
        max_input_len = K.cast(K.shape(y_pred)[0], dtype='int32')
        input_lengths = tf.ones((K.shape(y_pred)[1],), dtype='int32') * max_input_len
        return tf.nn.ctc_loss(
            labels=targets,
            logits=y_pred,
            label_length=target_lengths,
            logit_length=input_lengths,
            blank_index=self.blank_class,
        )
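For what it's worth, calling the wrapper's call() directly in eager mode shows that the loss itself works when y_true really is the dict (shapes, targets, and the blank class of 0 below are made up for illustration):

loss_fn = CTCLossWrapper(blank_class=0)
y_pred = tf.nn.softmax(tf.random.normal((2, 50, 80)))  # (B=2, T=50, C=80)
y_true = {'output': {
    'targets': tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int32),
    'target_lengths': tf.constant([3, 2], dtype=tf.int32),
}}
print(loss_fn.call(y_true, y_pred))  # per-sample losses; bypasses Keras' y_true handling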
I also wrote a simple generator function that produces training samples:
import numpy as np

def generator(dataset, batch_size: int, shuffle=False):
    indexes = np.arange(len(dataset))
    while True:
        if shuffle:
            indexes = np.random.permutation(indexes)
        for i in range(0, len(dataset), batch_size):
            # Get next batch
            batch = dataset[indexes[i:i + batch_size]]
            images, image_widths = batch['images'], batch['image_widths']
            targets, target_lengths = batch['targets'], batch['target_lengths']
            # Re-arrange dimensions (B, H, W, C) -> (B, W, H, C)
            # Important note: width=W and height=H are swapped from the typical Keras
            # convention because width is the time dimension once it is fed into the RNN
            images = np.transpose(images, axes=(0, 2, 1, 3)).astype(np.float32) / 255.0
            # Change zero target lengths to 1 due to the invalid implementation
            # of ctc_batch_cost in Keras
            target_lengths[target_lengths == 0] = 1
            # Add singleton dimension
            # image_widths = image_widths[:, np.newaxis]
            # target_lengths = target_lengths[:, np.newaxis]
            # Construct output value
            outputs = {
                'images': images,                  # (batch_size, max_image_width, 32, 1)
                'image_widths': image_widths,      # (batch_size,)
                'targets': targets,                # (batch_size, max_target_len)
                'target_lengths': target_lengths,  # (batch_size,)
            }
            yield images, dict(output=outputs)
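To sanity-check the shapes, one can pull a single batch from the generator (assuming train_dataset supports the fancy indexing used above; the batch size of 8 is arbitrary):

x, y = next(generator(train_dataset, batch_size=8))
print(x.shape)                                       # e.g. (8, max_image_width, 32, 1)
print({k: v.shape for k, v in y['output'].items()})  # the four values described below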
As you can see, the generator yields not just (x, y_true) but four values, because tf.nn.ctc_loss needs at least four arguments to work. My plan was to pass the input images as x and the dict with all four values as y_true.
Then, of course, I compile the model with CTCLossWrapper and the blank_class:
from tensorflow.keras.optimizers import Adam

model.compile(
    optimizer=Adam(),
    loss=CTCLossWrapper(blank_class=blank_class),
)
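Incidentally, to inspect what the loss actually receives, the model can also be compiled eagerly (run_eagerly is a standard tf.keras compile flag), which makes it possible to print() or set breakpoints inside call():

model.compile(
    optimizer=Adam(),
    loss=CTCLossWrapper(blank_class=blank_class),
    run_eagerly=True,  # disables graph compilation for easier debugging
)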
After that, I can start training via:
model.fit(
    x=generator(train_dataset, batch_size=batch_size, shuffle=True),
    steps_per_epoch=int(len(train_dataset) // batch_size),
    epochs=200,
)
The problem is that when my CTCLossWrapper gets called, it does not receive a dict() as y_true. It only receives one of the tensors from it.
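For example, a throwaway diagnostic loss (hypothetical, only to show what arrives) reports a single plain tensor rather than the dict structure built above:

class DebugLoss(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        # Prints one tensor's type and shape, not the dict built in the generator
        tf.print('y_true is a', type(y_true).__name__, 'of shape', tf.shape(y_true))
        # Dummy value; the real CTC loss cannot be computed from one tensor
        return tf.zeros((tf.shape(y_pred)[0],))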
How can I avoid or turn off this TensorFlow preprocessing and get the y_true value in the same form in which it is supplied by the dataset?