使用tflearn在其他过程中加载训练有素的模型

时间:2017-12-19 13:58:00

标签: python tensorflow tflearn

我正在尝试保存模型并使用tflearn库将其加载到其他进程中...

所以我生成了模型:

lenx = 21908
leny = 81
# Build neural network
net = tflearn.input_data(shape=[None, lenx])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, leny, activation='softmax')
net = tflearn.regression(net)

# Define model and setup tensorboard
model = tflearn.DNN(net, tensorboard_dir='tflearn_logs')
# Start training (apply gradient descent algorithm)
model.fit(train_x, train_y, n_epoch=10, batch_size=8, show_metric=True)
model.save('model.tflearn')

那很好用! 然后在其他文件中,要在其他进程中运行,我试图以这种方式加载它:

lenx = 21908
leny = 81
# Build neural network
net = tflearn.input_data(shape=[None, lenx])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, leny, activation='softmax')
net = tflearn.regression(net)

model = tflearn.DNN(net, tensorboard_dir='tflearn_logs')

model.load("model.tflearn")

但我收到了这个错误:

ValueError: Cannot feed value of shape (1, 0) for Tensor 'InputData/X:0', which has shape '(?, 21908)'

我尝试了很多东西,但它不起作用。

2 个答案:

答案 0 :(得分:0)

您将成为的神经网络架构 加载应相同。另外,您不必定义 最后一层,因为那里 无需培训。不要用线 net = tflearn.regression(net)

答案 1 :(得分:-1)

我认为你缺少的是" weights_only = True"加载中的参数:

model.load("model.tflearn", weights_only=True)