CTC损失下降并停止

时间:2018-03-26 03:10:56

标签: python neural-network computer-vision deep-learning pytorch

我正在尝试训练验证码识别模型。模型细节是resnet预训练CNN层+双向LSTM +完全连接。它在python库captcha生成的验证码上达到了90%的序列精度。问题是这些生成的验证码似乎具有相似的每个角色的位置。当我在字符之间随机添加空格时,模型不再起作用。所以我想知道LSTM在学习过程中学习分段吗?然后我尝试使用CTC损失。起初,损失很快就会下降。但它保持在16左右而后期没有显着下降。我尝试了不同层次的LSTM,不同数量的单位。 2层LSTM损耗较低,但仍未收敛。 3层就像2层。损失曲线:

enter image description here

#encoding:utf8
import os
import sys
import torch
import warpctc_pytorch
import traceback

import torchvision
from torch import nn, autograd, FloatTensor, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from tensorboard import SummaryWriter
from pprint import pprint

from net.utils import decoder

from logging import getLogger, StreamHandler
logger = getLogger(__name__)
handler = StreamHandler(sys.stdout)
logger.addHandler(handler)

from dataset_util.utils import id_to_character
from dataset_util.transform import rescale, normalizer
from config.config import MAX_CAPTCHA_LENGTH, TENSORBOARD_LOG_PATH, MODEL_PATH


class CNN_RNN(nn.Module):
    def __init__(self, lstm_bidirectional=True, use_ctc=True, *args, **kwargs):
        super(CNN_RNN, self).__init__(*args, **kwargs)
        model_conv = torchvision.models.resnet18(pretrained=True)
        for param in model_conv.parameters():
            param.requires_grad = False

        modules = list(model_conv.children())[:-1]  # delete the last fc layer.
        for param in modules[8].parameters():
            param.requires_grad = True

        self.resnet = nn.Sequential(*modules)            # CNN with fixed parameters from resnet as feature extractor
        self.lstm_input_size = 512 * 2 * 2
        self.lstm_hidden_state_size = 512
        self.lstm_num_layers = 2
        self.chracter_space_length = 64
        self._lstm_bidirectional = lstm_bidirectional
        self._use_ctc = use_ctc
        if use_ctc:
            self._max_captcha_length = int(MAX_CAPTCHA_LENGTH * 2)
        else:
            self._max_captcha_length = MAX_CAPTCHA_LENGTH

        if lstm_bidirectional:
            self.lstm_hidden_state_size = self.lstm_hidden_state_size * 2           # so that hidden size for one direction in bidirection lstm is the same as vanilla lstm
            self.lstm = self.lstm = nn.LSTM(self.lstm_input_size, self.lstm_hidden_state_size // 2, dropout=0.5, bidirectional=True, num_layers=self.lstm_num_layers)
        else:
            self.lstm = nn.LSTM(self.lstm_input_size, self.lstm_hidden_state_size, dropout=0.5, bidirectional=False, num_layers=self.lstm_num_layers)  # dropout doen't work for one layer lstm

        self.ouput_to_tag = nn.Linear(self.lstm_hidden_state_size, self.chracter_space_length)
        self.tensorboard_writer = SummaryWriter(TENSORBOARD_LOG_PATH)
        # self.dropout_lstm = nn.Dropout()


    def init_hidden_status(self, batch_size):
        if self._lstm_bidirectional:
            self.hidden = (autograd.Variable(torch.zeros((self.lstm_num_layers * 2, batch_size, self.lstm_hidden_state_size // 2))),
                           autograd.Variable(torch.zeros((self.lstm_num_layers * 2, batch_size, self.lstm_hidden_state_size // 2)))) # number of layers, batch size, hidden dimention
        else:
            self.hidden = (autograd.Variable(torch.zeros((self.lstm_num_layers, batch_size, self.lstm_hidden_state_size))),
                           autograd.Variable(torch.zeros((self.lstm_num_layers, batch_size, self.lstm_hidden_state_size)))) # number of layers, batch size, hidden dimention


    def forward(self, image):
        '''
        :param image:  # batch_size, CHANNEL, HEIGHT, WIDTH
        :return:
        '''
        features = self.resnet(image)                 # [batch_size, 512, 2, 2]
        batch_size = image.shape[0]
        features = [features.view(batch_size, -1) for i in range(self._max_captcha_length)]
        features = torch.stack(features)
        self.init_hidden_status(batch_size)
        output, hidden = self.lstm(features, self.hidden)
        # output = self.dropout_lstm(output)
        tag_space = self.ouput_to_tag(output.view(-1, output.size(2)))      # [MAX_CAPTCHA_LENGTH * BATCH_SIZE, CHARACTER_SPACE_LENGTH]
        tag_space = tag_space.view(self._max_captcha_length, batch_size, -1)

        if not self._use_ctc:
            tag_score = F.log_softmax(tag_space, dim=2)             # [MAX_CAPTCHA_LENGTH, BATCH_SIZE, CHARACTER_SPACE_LENGTH]
        else:
            tag_score = tag_space

        return tag_score


    def train_net(self, data_loader, eval_data_loader=None, learning_rate=0.008, epoch_num=400):
        try:
            if self._use_ctc:
                loss_function = warpctc_pytorch.warp_ctc.CTCLoss()
            else:
                loss_function = nn.NLLLoss()

            # optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.parameters()), momentum=0.9, lr=learning_rate)
            # optimizer = MultiStepLR(optimizer, milestones=[10,15], gamma=0.5)

            # optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()))
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()))
            self.tensorboard_writer.add_scalar("learning_rate", learning_rate)

            tensorbard_global_step=0
            if os.path.exists(os.path.join(TENSORBOARD_LOG_PATH, "resume_step")):
                with open(os.path.join(TENSORBOARD_LOG_PATH, "resume_step"), "r") as file_handler:
                    tensorbard_global_step = int(file_handler.read()) + 1

            for epoch_index, epoch in enumerate(range(epoch_num)):
                for index, sample in enumerate(data_loader):
                    optimizer.zero_grad()
                    input_image = autograd.Variable(sample["image"])        # batch_size, 3, 255, 255
                    tag_score = self.forward(input_image)

                    if self._use_ctc:
                        tag_score, target, tag_score_sizes, target_sizes = self._loss_preprocess_ctc(tag_score, sample)
                        loss = loss_function(tag_score, target, tag_score_sizes, target_sizes)
                        loss = loss / tag_score.size(1)

                    else:
                        target = sample["padded_label_idx"]
                        tag_score, target = self._loss_preprocess(tag_score, target)
                        loss = loss_function(tag_score, target)

                    print("Training loss: {}".format(float(loss)))
                    self.tensorboard_writer.add_scalar("training_loss", float(loss), tensorbard_global_step)
                    loss.backward()
                    optimizer.step()

                    if index % 250 == 0:
                        print(u"Processing batch: {} of {}, epoch: {}".format(index, len(data_loader), epoch_index))
                        self.evaluate(eval_data_loader, loss_function, tensorbard_global_step)

                    tensorbard_global_step += 1

                self.save_model(MODEL_PATH + "_epoch_{}".format(epoch_index))

        except KeyboardInterrupt:
            print("Exit for KeyboardInterrupt, save model")
            self.save_model(MODEL_PATH)

            with open(os.path.join(TENSORBOARD_LOG_PATH, "resume_step"), "w") as file_handler:
                file_handler.write(str(tensorbard_global_step))

        except Exception as excp:
            logger.error(str(excp))
            logger.error(traceback.format_exc())


    def predict(self, image):
        # TODO ctc version
        '''
        :param image: [batch_size, channel, height, width]
        :return:
        '''
        tag_score = self.forward(image)
        # TODO ctc
        # if self._use_ctc:
        #     tag_score = F.softmax(tag_score, dim=-1)
        #     decoder.decode(tag_score)

        confidence_log_probability, indexes = tag_score.max(2)

        predicted_labels = []
        for batch_index in range(indexes.size(1)):
            label = ""
            for character_index in range(self._max_captcha_length):
                if int(indexes[character_index, batch_index]) != 1:
                    label += id_to_character[int(indexes[character_index, batch_index])]
            predicted_labels.append(label)

        return predicted_labels, tag_score


    def predict_pil_image(self, pil_image):
        try:
            self.eval()
            processed_image = normalizer(rescale({"image": pil_image}))["image"].view(1, 3, 255, 255)
            result, tag_score = self.predict(processed_image)
            self.train()

        except Exception as excp:
            logger.error(str(excp))
            logger.error(traceback.format_exc())
            return [""], None

        return result, tag_score


    def evaluate(self, eval_dataloader, loss_function, step=0):
        total = 0
        sequence_correct = 0
        character_correct = 0
        character_total = 0
        loss_total = 0
        batch_size = eval_data_loader.batch_size
        true_predicted = {}
        self.eval()
        for sample in eval_dataloader:
            total += batch_size
            input_images = sample["image"]
            predicted_labels, tag_score = self.predict(input_images)

            for predicted, true_label in zip(predicted_labels, sample["label"]):
                if predicted == true_label:                  # dataloader is making label a list, use batch_size=1
                    sequence_correct += 1

                for index, true_character in enumerate(true_label):
                    character_total += 1
                    if index < len(predicted) and predicted[index] == true_character:
                        character_correct += 1

                true_predicted[true_label] = predicted

            if self._use_ctc:
                tag_score, target, tag_score_sizes, target_sizes = self._loss_preprocess_ctc(tag_score, sample)
                loss_total += float(loss_function(tag_score, target, tag_score_sizes, target_sizes) / batch_size)

            else:
                tag_score, target = self._loss_preprocess(tag_score, sample["padded_label_idx"])
                loss_total += float(loss_function(tag_score, target))  # averaged over batch index

        print("True captcha to predicted captcha: ")
        pprint(true_predicted)
        self.tensorboard_writer.add_text("eval_ture_to_predicted", str(true_predicted), global_step=step)

        accuracy = float(sequence_correct) / total
        avg_loss = float(loss_total) / (total / batch_size)
        character_accuracy = float(character_correct) / character_total
        self.tensorboard_writer.add_scalar("eval_sequence_accuracy", accuracy, global_step=step)
        self.tensorboard_writer.add_scalar("eval_character_accuracy", character_accuracy, global_step=step)
        self.tensorboard_writer.add_scalar("eval_loss", avg_loss, global_step=step)
        self.zero_grad()
        self.train()


    def _loss_preprocess(self, tag_score, target):
        '''
        :param tag_score:  value return by self.forward
        :param target:     sample["padded_label_idx"]
        :return:           (processed_tag_score, processed_target)  ready for NLLoss function
        '''
        target = target.transpose(0, 1)
        target = target.contiguous()
        target = target.view(target.size(0) * target.size(1))
        tag_score = tag_score.view(-1, self.chracter_space_length)

        return tag_score, target


    def _loss_preprocess_ctc(self, tag_score, sample):
        target_2d = [
            [int(ele) for ele in sample["padded_label_idx"][row, :] if int(ele) != 0 and int(ele) != 1]
            for row in range(sample["padded_label_idx"].size(0))]
        target = []
        for ele in target_2d:
            target.extend(ele)
        target = autograd.Variable(torch.IntTensor(target))

        # tag_score = F.softmax(F.sigmoid(tag_score), dim=-1)
        tag_score_sizes = autograd.Variable(torch.IntTensor([self._max_captcha_length] * tag_score.size(1)))
        target_sizes = autograd.Variable(sample["captcha_length"].int())

        return tag_score, target, tag_score_sizes, target_sizes


    # def visualize_graph(self, dataset):
    #     '''Since pytorch use dynamic graph, an input is required to visualize graph in tensorboard'''
    #     # warning: Do not run this, the graph is too large to visualize...
    #     sample = dataset[0]
    #     input_image = autograd.Variable(sample["image"].view(1, 3, 255, 255))
    #     tag_score = self.forward(input_image)
    #     self.tensorboard_writer.add_graph(self, tag_score)


    def save_model(self, model_path):
        self.tensorboard_writer.close()
        self.tensorboard_writer = None          # can't be pickled
        torch.save(self, model_path)
        self.tensorboard_writer = SummaryWriter(TENSORBOARD_LOG_PATH)


    @classmethod
    def load_model(cls, model_path=MODEL_PATH, *args, **kwargs):
        net = cls(*args, **kwargs)
        if os.path.exists(model_path):
            model = torch.load(model_path)
            if model:
                model.tensorboard_writer = SummaryWriter(TENSORBOARD_LOG_PATH)
                net = model

        return net


    def __del__(self):
        if self.tensorboard_writer:
            self.tensorboard_writer.close()


if __name__ == "__main__":
    from dataset_util.dataset import dataset, eval_dataset
    data_loader = DataLoader(dataset, batch_size=2, shuffle=True)
    eval_data_loader = DataLoader(eval_dataset, batch_size=2, shuffle=True)

    net = CNN_RNN.load_model()

    net.train_net(data_loader, eval_data_loader=eval_data_loader)
    # net.predict(dataset[0]["image"].view(1, 3, 255, 255))

    # predict_pil_image test code
    # from config.config import IMAGE_PATHS
    # import glob
    # from PIL import Image
    #
    # image_paths = glob.glob(os.path.join(IMAGE_PATHS.get("EVAL"), "*.png"))
    # for image_path in image_paths:
    #     pil_image = Image.open(image_path)
    #     predicted, score = net.predict_pil_image(pil_image)
    #     print("True value: {}, predicted: {}".format(os.path.split(image_path)[1], predicted))

    print("Done")

以上代码是主要部分。如果您需要其他组件使其运行,请发表评论。在这里停留了很长时间。任何有关培训crnn + ctc的建议都表示赞赏。

3 个答案:

答案 0 :(得分:1)

我一直在用 ctc loss 进行训练,遇到了同样的问题。我知道这是一个相当晚的答案,但希望它会帮助正在研究此问题的其他人。经过反复试验和大量研究,在使用 ctc 进行训练时,有几件事值得了解(如果您的模型设置正确):

  1. 该模型降低成本的最快方法是仅预测空白。这在一些论文和博客中有所说明:参见 http://www.tbluche.com/ctc_and_blank.html
  2. 该模型首先学习仅预测空白,然后开始接收与正确底层标签相关的错误信号。这也在上面的链接中进行了解释。在实践中,我注意到我的模型在几百个 epoch 后开始学习真正的底层标签/目标,并且损失再次开始急剧下降。类似于此处显示的玩具示例:https://thomasmesnard.github.io/files/CTC_Poster_Mesnard_Auvolat.pdf
  3. 这些参数对您的模型是否收敛有很大影响 - 学习率、批量大小和纪元数。

答案 1 :(得分:0)

你有几个问题,所以我会尝试逐一回答。

首先,为什么在验证码中添加空格会破坏模型?

神经网络学会处理它所训练的数据。如果更改数据的分布(例如,在字符之间添加空格),则无法保证网络会一般化。正如你在提问中暗示的那样。您训练的验证码可能始终具有相同位置或相同距离的角色,因此您的模型可以学习并通过查看这些位置来学习利用它。如果您希望网络推广特定方案,则应明确针对该方案进行培训。因此,在您的情况下,您还应该在训练期间添加随机空格。

其次,为什么损失不会低于16?

显然,由于你的训练损失在16时也停滞不前(就像你的验证损失一样),问题在于你的模型根本没有能力处理问题的复杂性。换句话说,你的模型是不合适的。您有正确的反应来尝试增加网络容量。您试图增加LSTM的容量,但它没有帮助。因此,下一个合乎逻辑的步骤是网络的卷积部分不够强大。所以这里有一些你可能想要尝试的事情,从我认为最有可能成功到最不可能的事情:

  1. 让convnet成为可训练的:我注意到你正在使用一个预训练的回旋网,并且你没有微调那个回旋网的权重。这可能是个问题。无论您的培训是什么,它都可能无法开发出处理验证码所需的功能。你也应该尝试学习convnet的权重,以便为验证码开发有用的功能。

  2. 使用更深入的网络:这是天真的事情。你的小天鹅没有足够好的功能,尝试更强大的功能。 (但是,只有在你已经训练了这个网站之后你才能使用它。)

答案 2 :(得分:0)

根据我的经验,训练具有CTC损失的RNN模型并不是一项简单的任务。如果没有仔细设置培训,模型可能根本不会收敛。以下是我的建议:

  1. 检查培训中的CTC损失输出。对于模型会收敛,每批的CTC损失显着波动。如果您发现CTC损失几乎单调收缩到稳定值,那么该模型很可能会陷入局部最小值
  2. 使用短样本预训练您的模型。虽然我们已经推进了LSTMGRU等高级RNN结构,但仍然难以向后传播RNN以进行长时间的步骤。
  3. 放大样品种类。您甚至可以添加人工样本来帮助您的模型逃避来自本地最小值。
  4. F.Y.I。,我们刚开源了一个新的深度学习框架Dandelion,它具有内置的CTC目标,界面与pytorch非常相似。您可以使用蒲公英尝试您的模型,并将其与您当前的实施进行比较。