我正在寻找一种使用tflearn(tflearn.org)保存和恢复网络结构的方法。嵌入的model.save关注权重持久性。但是如果我想忘记网络结构并再次加载它,我必须再次构建网络。因此,我必须“记住”结构(层数,神经元数量)等。
我尝试过pickle(类似于pickle.dump(net))和tf.train.import_meta_graph(将tflearn.org与tensorflow结合使用)但没有成功。
假设我接受过这样的培训:
from future import division, print_function, absolute_import
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = np.reshape(X, (-1, 28, 28))
testX = np.reshape(testX, (-1, 28, 28))
net = tflearn.input_data(shape=[None, 28, 28])
net = tflearn.lstm(net, 128, return_seq=True)
net = tflearn.lstm(net, 128)
net = tflearn.regression(net, optimizer='adam',
loss='categorical_crossentropy', name="output1")
model = tflearn.DNN(net, tensorboard_verbose=2,tensorboard_dir='lstm')
model.fit(X, Y, n_epoch=1, validation_set=0.1, show_metric=True,
snapshot_step=100)
model.save('lstm_load.tflearn')
后来我想在不同的程序或会话中恢复所有需要的东西(当然伪代码不起作用; - ))
from future import division, print_function, absolute_import
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = np.reshape(X, (-1, 28, 28))
testX = np.reshape(testX, (-1, 28, 28))
model = model.load_everything("model.tfl")
prediction = model.predict([X])
有什么想法吗?