恢复Tensorflow检查点文件时出错

时间:2015-12-16 16:39:47

标签: python tensorflow

在tensorflow中使用saver.restore()方法时出现以下错误。知道为什么会这样吗?

我保存了这样的模型: saver.save(sess, checkpoint_path, global_step=step)

错误是:

tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Variable_1/initial_value = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [] values: 0.9>]()]]

完整追踪:

can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
('1.1- label batch shape is ', TensorShape([Dimension(128)]))
Inferencing
('in inferemcee ', TensorShape([Dimension(128), Dimension(3072)]), <class 'tensorflow.python.framework.ops.Tensor'>)
Evaluation..
tmp/ckpt/model.ckpt-9100
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789748be0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/string_input_producer/string_input_producer_EnqueueMany = QueueEnqueueMany[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/string_input_producer, input/string_input_producer/limit_epochs)]]
I tensorflow/core/kernels/fifo_queue.cc:154] Skipping cancelled enqueue attempt
Traceback (most recent call last):
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78b939670 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954f080 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954e5d0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789550370 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78ba28cb0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 85, in my_eval
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 864, in restore
    sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 345, in run
    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 419, in _do_run
    e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Reshape/shape = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [4] values: -1 32 32...>]()]]
Caused by op u'Reshape/shape', defined at:
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 78, in my_eval
    logits = my_cifar.inference(images_placeholder)
  File "/ProjectS/Cifar-Eval/my_cifar.py", line 68, in inference
    images = tf.reshape(images, shape=[-1, 32, 32, 3])
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 554, in reshape
    name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 397, in apply_op
    values, name=input_arg.name, dtype=dtype)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 468, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 147, in constant
    attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1710, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 988, in __init__
    self._traceback = _extract_stack()

我的恢复检查点文件的代码

import tensorflow as tf

import my_cifar
import my_input

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', 'tmp/log_eval',
                           """Directory where to write event logs.""")

tf.app.flags.DEFINE_string('checkpoint_dir', 'tmp/ckpt',
                           """Directory where to read model checkpoints.""")


IMAGE_PIXELS = 32 * 32 * 3


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the the input tensors.
  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded ckpt in the .run() loop, below.
  Args:
    batch_size: The batch size will be baked into both placeholders.
  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test ckpt sets.
  # batch_size = -1
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         IMAGE_PIXELS))
  # 32, 32, 3))
  labels_placeholder = tf.placeholder(tf.int32, shape=batch_size)

  return images_placeholder, labels_placeholder


def my_eval():
  with tf.Graph().as_default():

    v1 = tf.Variable(0)

    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

    # Get images and labels for CIFAR-10.
    val_images, val_labels = my_input.inputs(False)

    init_op = tf.initialize_all_variables()

    coord = tf.train.Coordinator()

    with tf.Session() as sess:

      sess.run(init_op)

      saver = tf.train.Saver()
      # Start the queue runners.

      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      summary_op = tf.merge_all_summaries()
      summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                              graph_def=sess.graph_def)


      # Build a Graph that computes the logits predictions from the
      # inference model.
      logits = my_cifar.inference(images_placeholder)

      acc = my_cifar.evaluation(logits, labels_placeholder)

      ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
      print ckpt.model_checkpoint_path
      if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Restored!')

      images_val_r, labels_val_r = sess.run([val_images, val_labels])
      val_feed = {images_placeholder: images_val_r,
                  labels_placeholder: labels_val_r}

      tf.scalar_summary('Acc', acc)

      print('Calculating Acc  :')

      acc_r = sess.run(acc, feed_dict=val_feed)
      print(acc_r)

      # Write results to TensorBoard
      summary_str = sess.run(summary_op)
      summary_writer.add_summary(summary_str)


      coord.join(threads)


def main(argv=None):
  my_eval()


if __name__ == '__main__':
  tf.app.run()

1 个答案:

答案 0 :(得分:0)

您正在尝试加载原始网络中不存在的变量,我相信省略

    v1 = tf.Variable(0)

将解决问题。

如果要添加新变量,则需要以不同的方式加载它,加载方法应如下:

reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
restore_dict = dict()
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    if reader.has_tensor(tensor_name):
        print('has tensor ', tensor_name)
        restore_dict[tensor_name] = v
    # put the logic of the new/modified variable here and assign to the restore_dict, i.e. 
    # restore_dict['my_var_scope/my_var'] = get_my_variable()