是否可以在Tensorflow中更新现有的文本分类模型?

时间:2018-08-30 12:58:08

标签: python tensorflow text-classification resuming-training

我是Python的新手,并且一直使用tensorflow执行文本分类。我想知道此文本分类模型是否可以用将来可能获取的每个新数据进行更新,这样我就不必从头开始训练模型。另外,有时随着时间的流逝,类的数量也可能会更多,因为我主要处理客户数据。通过使用现有检查点,是否可以使用包含更多类的数据来更新此现有文本分类模型?

1 个答案:

答案 0 :(得分:0)

鉴于您要问两个不同的问题,我现在分别回答这两个问题:

1)是,您可以使用已获取的新数据继续训练。这非常简单,您只需要像现在一样还原模型即可使用它。而不是运行诸如输出或预测之类的占位符,您应该运行优化器操作。 这将转换为以下代码:

model = build_model() # this is the function that build the model graph
saver = tf.train.Saver()
with tf.Session() as session:
    saver.restore(session, "/path/to/model.ckpt")
########### keep training #########
data_x, data_y = load_new_data(new_data_path)
for epoch in range(1, epochs+1):
    all_losses = list()
    num_batches = 0
    for b_x, b_y in batchify(data_x, data_y)
        _, loss = session.run([model.opt, model.loss], feed_dict={model.input:b_x, model.input_y : b_y}
        all_losses.append(loss * len(batch_x))
        num_batches += 1
    print("epoch %d - loss: %2f" % (epoch, sum(losses) / num_batches))

请注意,您现在需要由模型定义的操作的名称,以便运行优化器(model.opt)和损失op(model.loss)来训练和监视训练期间的损失。

2)如果要更改要使用的标签数量,则要复杂一些。如果您的网络是1层前馈,那么就没什么要做的,因为您需要更改矩阵维数,那么就需要从头开始重新训练一切。另一方面,例如,如果您具有多层网络(例如,进行分类的LSTM +密集层),则可以恢复旧模型的权重,而从头开始训练最后一层。为此,我建议您阅读此答案https://stackoverflow.com/a/41642426/4186749