如何导入(恢复)由文件中的tflearn构建的神经网络模型

时间:2018-05-07 18:14:46

标签: python tensorflow neural-network text-classification tflearn

我指的是关于文本分类的this tutorial,并为文本分类构建了自定义训练集。

我使用以下代码保存模型。

# Define model and setup tensorboard
model = tflearn.DNN(net, tensorboard_dir='tflearn_logs')
# Start training (apply gradient descent algorithm)
model.fit(train_x, train_y, n_epoch=1000, batch_size=8, show_metric=True)
model.save('model.tflearn')

这将生成以下文件。

model.tflearn.data-00000-of-00001
model.tflearn.index
model.tflearn.meta
tflearn_logs folder

我想使用不同迭代中构建的模型进行测试。

我试过了,

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.tflearn.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

但是我得到了;

  

KeyError:“名称'adam'指的是不在图表中的操作。”错误

我知道from documentation tflearn.DNN(network).load('file_name')加载模型,但我们需要创建并传递网络实例,构建网络我们再次从头开始查看相同的代码,这需要时间,因为它会做我希望避免的训练。

构建网络代码

net = tflearn.input_data(shape=[None, len(train_x[0])])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, len(train_y[0]), activation='softmax')
net = tflearn.regression(net)

tflearn.input_data将形状输入作为强制性输入,因此我们再次需要再次输入训练数据。因此它会导致重建模型。 我检查了文档,找不到我需要的东西(2-3行代码可以导入构建神经网络模型以节省再培训时间。

如果你们知道解决方案,请告诉我。

Similar question但不重复

  • OP在构建树时构建神经网络时面临问题,而我正面临导入构建模型的问题。
  • 答案中提到的教程没有tflearn NN模型导入

1 个答案:

答案 0 :(得分:0)

我能够使用以下代码恢复已保存的模型。

tflearn可以从保存的日志和模型文件中恢复模型。

创建与保存模型相同大小的虚拟神经网络

注意:您可能需要跟踪以前保存的模型的权重(输入培训的大小和相应的类别)

net = tflearn.input_data(shape=[None, train_x[0]])
net = tflearn.fully_connected(net, 8, restore=False)
net = tflearn.fully_connected(net, 8, restore=False)
net = tflearn.fully_connected(net, train_y[0], activation='softmax', restore=False)
dnn = tflearn.DNN(net, tensorboard_dir='tflearn_logs')

将保存的模型加载到DNN

model = dnn.load('./model.tflearn')
使用加载的模型进行预测
test_data = ###converted data 
model.predict(test_data)