TFlearn保存和加载模块没有网络初始化

时间:2017-11-30 04:27:33

标签: python tflearn rnn

您好我已经创建了一个模型并使用TFlearn

进行训练
# Network building
net = input_data(shape=[None, max_document_length])
net = embedding(net, input_dim=Vocabulary_size, output_dim=128)
net = simple_rnn(net, 512)
net = dropout(net, 0.5)
net = fully_connected(net, n_class, activation='softmax')
net = regression(net, optimizer='adam', loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, clip_gradients=0.)
model.fit(doc_vec, label, validation_set=0.1, show_metric=True, batch_size=100)

#saving the files
model.save('path')

但是当我尝试在新文件中恢复此模型时,我需要再次提供整个网络构建并加载模型并进行预测。

net = input_data(shape=[None, max_document_length])
net = embedding(net, input_dim=Vocabulary_size, output_dim=128)
net = simple_rnn(net, 512)
net = dropout(net, 0.5)
net = fully_connected(net, n_class, activation='softmax')
net = regression(net, optimizer='adam', loss='categorical_crossentropy')
model = tflearn.DNN(net, clip_gradients=0.)
model.load('path')
print(model.predict_label(input))

我试图腌制保存网和模型,但它抛出了一个错误。 那么我该怎样做才能确保我不提供整个网络,只需保存模型并恢复它并使用model.predict,就像使用sklearn一样。

0 个答案:

没有答案