将训练好的模型加载到张量流中

时间:2017-08-07 16:44:05

标签: tensorflow neural-network

我在Tensorflow中构建了一个模型,我已经训练过了。现在我想使用输出,所以我想将Checkpoint,Meta和所有其他文件加载到tensorlow。

我使用以下代码训练模型:

# Logging
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train')
test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')
validate_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/validate')
writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph)
saver = tf.train.Saver()  # for storing the best network

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Best validation accuracy seen so far
bestValidation = -0.1

# Training loop
coord = tf.train.Coordinator() # coordinator for threads
threads = tf.train.start_queue_runners(coord = coord, sess=sess) # start queue thread

# Training loop
for i in range(FLAGS.maxIter):
    xTrain, yTrain = sess.run(data_batch)
    sess.run(train_step, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])})
    summary = sess.run(merged, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])})
    train_writer.add_summary(summary, i)
    if ((i + 1) % 10 == 0):
        print("Iteration:", i + 1, "/", FLAGS.maxIter)
        summary = sess.run(merged, feed_dict={x_data: dataTest.data, y_target: np.transpose([dataTest.target])})
        test_writer.add_summary(summary, i)
        currentValidation, summary = sess.run([accuracy, merged], feed_dict={x_data: dataTest.data,
                                                                             y_target: np.transpose(
                                                                                 [dataTest.target])})
    validate_writer.add_summary(summary, i)
    if (currentValidation > bestValidation and currentValidation <= 0.9):
        bestValidation = currentValidation
        saver.save(sess=sess, save_path=FLAGS.summary_dir + '/bestNetwork')
        print("\tbetter network stored,", currentValidation, ">", bestValidation)

coord.request_stop()  # ask threads to stop
coord.join(threads)  # wait for threads to stop

现在我想将模型加载回Tensorflow。我希望能够做一些事情:

  • 使用我已经为训练和测试数据集创建的输出。
  • 将新数据加载到模型中,然后能够使用相同的权重来生成新输出。

我已尝试使用以下代码将模型加载回tensorflow,但它不起作用:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(FLAGS.summary_dir + '/bestNetwork.meta')
    saver.restore(sess,tf.train.latest_checkpoint(FLAGS.summary_dir + '/checkpoint'))

运行代码时出现以下错误:

TypeError:预期字节,找到NoneType

正如我所说,我使用tf.train.import_meta_graph()函数加载上一节中的元图,然后使用检查点部分加载权重。那为什么这不起作用?

1 个答案:

答案 0 :(得分:0)

您将模型保存为bestNetwork。试试这个:

saver.restore(sess,tf.train.latest_checkpoint(FLAGS.summary_dir + '/**bestNetwork**'))