如何将y_true作为dict传递给自定义损失函数?

时间:2020-06-30 11:32:49

标签: python tensorflow keras customization loss-function

我需要使用 tf.keras.* 实现简单的OCR模型。 但是:

  • 空白类不像tf期望的那样为零,而是(num_classes - 1)
  • 事先不知道输入图像的宽度(不同批次的宽度不同)。

我想利用tf.nn.ctc_loss,它有一个很好的论点:blank_index。 因此,我做了一个简单的包装来计算CTC损失:

    class CTCLossWrapper(tf.keras.losses.Loss):
        def __init__(self, blank_class: int, reduction: str = tf.keras.losses.Reduction.AUTO, name: str = 'ctc_loss'):
            super().__init__(reduction=reduction, name=name)
            self.blank_class = blank_class
        
        def call(self, y_true, y_pred):
            output = y_true['output']
            targets, target_lenghts = output['targets'], output['target_lengths']
            y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + K.epsilon())
            max_input_len = K.cast(K.shape(y_pred)[1], dtype='int32')
            input_lengths = tf.ones((K.shape(y_pred)[0]), dtype='int32') * max_input_len
            return tf.nn.ctc_loss(
                labels=targets,
                logits=y_pred,
                label_length=target_lenghts,
                logit_length=input_lengths,
                blank_index=self.blank_class
            )

我还编写了一个简单的生成器函数,可以生成训练样本:

def generator(dataset, batch_size: int, shuffle=False):
    indexes = np.arange(len(dataset))
    while True:
        if shuffle:
            indexes = np.random.permutation(indexes)
        for i in range(0, len(dataset), batch_size):
            # Get next batch
            batch = dataset[indexes[i:i+batch_size]]
            images, image_widths = batch['images'], batch['image_widths']
            targets, target_lengths = batch['targets'], batch['target_lengths']
            # Re-arrange dimensions (B, H, W, C) -> (B, W, H, C)
            # Important Note: width=W and height=H are swapped from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            images = np.transpose(images, axes=(0, 2, 1, 3)).astype(np.float32) / 255.0
            # Change zero target length to 1 due to invalid implementation of ctc_batch_cost in keras
            target_lengths[target_lengths == 0] = 1
            # Add singleton dimension
            # image_widths = image_widths[:, np.newaxis]
            # target_lengths = target_lengths[:, np.newaxis]
            # Construct output value
            outputs = {
                'images': images,  # (batch_size, max_image_width, 32, 1)
                'image_widths': image_widths,  # (batch_size,)
                'targets': targets,  # (batch_size, max_target_len)
                'target_lengths': target_lengths,  # (batch_size,)
            }
            yield images, dict(output=outputs)

您可能会看到,生成器不仅输出(x, y_true),还输出4个值:

  • 输入图像,
  • 输入图像宽度
  • 目标序列,
  • 每个靶序列的长度。

之所以这样,是因为tf.nn.ctc_loss还需要至少4个参数才能工作。

我的计划是将输入图像传递为x,将所有4个值的字典传递为y_true

然后,我当然使用CTCLossWrapperblank_class编译模型:

    model.compile(
        optimizer=Adam(),
        loss=CTCLossWrapper(blank_class=blank_class),
    )

之后,我可以通过以下方式开始训练:

    model.fit(
        x=generator(train_dataset, batch_size=batch_size, shuffle=True),
        steps_per_epoch=int(len(train_dataset) // batch_size),
        epochs=200
    )

问题是,当我的CTCLossWrapper被调用时,它不会得到dict()为y_true。它仅从中获得张量之一。

如何避免或关闭张量流预处理并以与从数据集提供的相同形式获取y_true值?

0 个答案:

没有答案