在Tensorflow中保存和恢复会话模型

时间:2017-06-03 22:12:01

标签: python session tensorflow

有人可以帮助我,我想在张量流中保存我的模型,以便以后用它来预测。

这是我保存会话时的代码:

# Initializing the variables
init = tf.global_variables_initializer()

# Create Saver
saver = tf.train.Saver()

# Merge all the summaries
summaries = tf.summary.merge_all()

total_batch = 5
with tf.Session() as sess:
    sess.run(init)

    writer = tf.summary.FileWriter(log_path, sess.graph)

    # Training cycle
    for epoch in range(training_epochs):
        # Loop over all batches
        for batch in range(total_batch):
            batch_x, batch_y = batching(data_train, batch_size=50)

            # Run optimization op (backprop) and cost op (to get loss value)
            curr_loss, cur_accuracy, _, summary = sess.run([cost, accuracy, optimizer, summaries], feed_dict={x: batch_x, 
                                                                                                              y: batch_y})
            writer.add_summary(_, batch)

        # Display logs per epoch step
        if epoch % display_step == 0:
            writer.add_summary(summary, epoch * total_batch + batch)
            # Print the loss
            print("Epoch: %04d/%d. Batch: %d/%d. Current loss: %.5f. Train Accuracy: %.3f"
                      %(epoch, training_epochs, batch, total_batch, curr_loss, cur_accuracy))            

    # Test the session
    test_accuracy, test_predictions = sess.run([accuracy, y_p], feed_dict={ x: X_test, 
                                                                            y: labels_test})

    print("Test Accuracy: %.3f" % test_accuracy)

    saved_value = saver.save(sess, 'model/heart_disease')

我尝试恢复会话时的代码:

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('model/heart_disease.meta')
  saver.restore(sess, tf.train.latest_checkpoint('/checkpoint'))

但是我收到以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-24-eea1c921000b> in <module>()
     18 
     19 with tf.Session() as sess:
---> 20   new_saver = tf.train.import_meta_graph('model/heart_disease.meta')
     21   saver.restore(sess, tf.train.latest_checkpoint('/checkpoint'))

C:\Users\dangz\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py in import_meta_graph(meta_graph_or_file, clear_devices, import_scope, **kwargs)
   1593                                       clear_devices=clear_devices,
   1594                                       import_scope=import_scope,
-> 1595                                       **kwargs)
   1596   if meta_graph_def.HasField("saver_def"):
   1597     return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)

C:\Users\dangz\Anaconda3\lib\site-packages\tensorflow\python\framework\meta_graph.py in import_scoped_meta_graph(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name)
    497     importer.import_graph_def(
    498         input_graph_def, name=(import_scope or ""), input_map=input_map,
--> 499         producer_op_list=producer_op_list)
    500 
    501     # Restores all the other collections.

C:\Users\dangz\Anaconda3\lib\site-packages\tensorflow\python\framework\importer.py in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    306           node.op, [], output_types, name=node.name, attrs=node.attr,
    307           compute_shapes=False, compute_device=False,
--> 308           op_def=op_def)
    309 
    310     # 2. Add inputs to the operations.

C:\Users\dangz\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in create_op(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device)
   2337     if compute_shapes:
   2338       set_shapes_for_outputs(ret)
-> 2339     self._add_op(ret)
   2340     self._record_op_seen_by_control_dependencies(ret)
   2341 

C:\Users\dangz\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _add_op(self, op)
   2031       if op.name in self._nodes_by_name:
   2032         raise ValueError("cannot add op with name %s as that name "
-> 2033                          "is already used" % op.name)
   2034       self._nodes_by_id[op._id] = op
   2035       self._nodes_by_name[op.name] = op

ValueError: cannot add op with name layer1/biases/biases/Adam as that name is already used

0 个答案:

没有答案