如何在tflearn中保存和恢复网络架构

时间:2017-09-08 11:23:58

标签: python tensorflow deep-learning tflearn

我正在寻找一种使用tflearn(tflearn.org)保存和恢复网络结构的方法。嵌入的model.save关注权重持久性。但是如果我想忘记网络结构并再次加载它,我必须再次构建网络。因此,我必须“记住”结构(层数,神经元数量)等。

我尝试过pickle(类似于pickle.dump(net))和tf.train.import_meta_graph(将tflearn.org与tensorflow结合使用)但没有成功。

假设我接受过这样的培训:

from future import division, print_function, absolute_import

import numpy as np
import tflearn

import tflearn.datasets.mnist as mnist

X, Y, testX, testY = mnist.load_data(one_hot=True)
X = np.reshape(X, (-1, 28, 28))

testX = np.reshape(testX, (-1, 28, 28))

net = tflearn.input_data(shape=[None, 28, 28])
net = tflearn.lstm(net, 128, return_seq=True)
net = tflearn.lstm(net, 128)
net = tflearn.regression(net, optimizer='adam',
loss='categorical_crossentropy', name="output1")
model = tflearn.DNN(net, tensorboard_verbose=2,tensorboard_dir='lstm')
model.fit(X, Y, n_epoch=1, validation_set=0.1, show_metric=True,
snapshot_step=100)

model.save('lstm_load.tflearn')

后来我想在不同的程序或会话中恢复所有需要的东西(当然伪代码不起作用; - ))

from future import division, print_function, absolute_import

import numpy as np
import tflearn

import tflearn.datasets.mnist as mnist

X, Y, testX, testY = mnist.load_data(one_hot=True)
X = np.reshape(X, (-1, 28, 28))

testX = np.reshape(testX, (-1, 28, 28))

model = model.load_everything("model.tfl")
prediction = model.predict([X])

有什么想法吗?

0 个答案:

没有答案
相关问题