在tensorflow中使用saver.restore()
方法时出现以下错误。知道为什么会这样吗?
我保存了这样的模型:
saver.save(sess, checkpoint_path, global_step=step)
错误是:
tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
[[Node: Variable_1/initial_value = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [] values: 0.9>]()]]
完整追踪:
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
('1.1- label batch shape is ', TensorShape([Dimension(128)]))
Inferencing
('in inferemcee ', TensorShape([Dimension(128), Dimension(3072)]), <class 'tensorflow.python.framework.ops.Tensor'>)
Evaluation..
tmp/ckpt/model.ckpt-9100
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789748be0 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/string_input_producer/string_input_producer_EnqueueMany = QueueEnqueueMany[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/string_input_producer, input/string_input_producer/limit_epochs)]]
I tensorflow/core/kernels/fifo_queue.cc:154] Skipping cancelled enqueue attempt
Traceback (most recent call last):
File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
tf.app.run()
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78b939670 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954f080 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954e5d0 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789550370 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78ba28cb0 Compute status: Cancelled: Enqueue operation was cancelled
[[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
sys.exit(main(sys.argv))
File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
my_eval()
File "/ProjectS/Cifar-Eval/my_eval.py", line 85, in my_eval
saver.restore(sess, ckpt.model_checkpoint_path)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 864, in restore
sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 345, in run
results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 419, in _do_run
e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
[[Node: Reshape/shape = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [4] values: -1 32 32...>]()]]
Caused by op u'Reshape/shape', defined at:
File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
tf.app.run()
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
sys.exit(main(sys.argv))
File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
my_eval()
File "/ProjectS/Cifar-Eval/my_eval.py", line 78, in my_eval
logits = my_cifar.inference(images_placeholder)
File "/ProjectS/Cifar-Eval/my_cifar.py", line 68, in inference
images = tf.reshape(images, shape=[-1, 32, 32, 3])
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 554, in reshape
name=name)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 397, in apply_op
values, name=input_arg.name, dtype=dtype)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 468, in convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 147, in constant
attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1710, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 988, in __init__
self._traceback = _extract_stack()
我的恢复检查点文件的代码
import tensorflow as tf
import my_cifar
import my_input
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('eval_dir', 'tmp/log_eval',
"""Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', 'tmp/ckpt',
"""Directory where to read model checkpoints.""")
IMAGE_PIXELS = 32 * 32 * 3
def placeholder_inputs(batch_size):
"""Generate placeholder variables to represent the the input tensors.
These placeholders are used as inputs by the rest of the model building
code and will be fed from the downloaded ckpt in the .run() loop, below.
Args:
batch_size: The batch size will be baked into both placeholders.
Returns:
images_placeholder: Images placeholder.
labels_placeholder: Labels placeholder.
"""
# Note that the shapes of the placeholders match the shapes of the full
# image and label tensors, except the first dimension is now batch_size
# rather than the full size of the train or test ckpt sets.
# batch_size = -1
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
IMAGE_PIXELS))
# 32, 32, 3))
labels_placeholder = tf.placeholder(tf.int32, shape=batch_size)
return images_placeholder, labels_placeholder
def my_eval():
with tf.Graph().as_default():
v1 = tf.Variable(0)
images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)
# Get images and labels for CIFAR-10.
val_images, val_labels = my_input.inputs(False)
init_op = tf.initialize_all_variables()
coord = tf.train.Coordinator()
with tf.Session() as sess:
sess.run(init_op)
saver = tf.train.Saver()
# Start the queue runners.
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
graph_def=sess.graph_def)
# Build a Graph that computes the logits predictions from the
# inference model.
logits = my_cifar.inference(images_placeholder)
acc = my_cifar.evaluation(logits, labels_placeholder)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
print ckpt.model_checkpoint_path
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print('Restored!')
images_val_r, labels_val_r = sess.run([val_images, val_labels])
val_feed = {images_placeholder: images_val_r,
labels_placeholder: labels_val_r}
tf.scalar_summary('Acc', acc)
print('Calculating Acc :')
acc_r = sess.run(acc, feed_dict=val_feed)
print(acc_r)
# Write results to TensorBoard
summary_str = sess.run(summary_op)
summary_writer.add_summary(summary_str)
coord.join(threads)
def main(argv=None):
my_eval()
if __name__ == '__main__':
tf.app.run()
答案 0 :(得分:0)
您正在尝试加载原始网络中不存在的变量,我相信省略
v1 = tf.Variable(0)
将解决问题。
如果要添加新变量,则需要以不同的方式加载它,加载方法应如下:
reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v
# put the logic of the new/modified variable here and assign to the restore_dict, i.e.
# restore_dict['my_var_scope/my_var'] = get_my_variable()