Tensorflow:使用占位符

时间:2018-04-06 12:28:44

标签: python tensorflow machine-learning

如何替换从文件加载的GraphDef中的占位符以将导入的图形连接到数据集提供程序?

此脚本大量借用eval_image_classifier.py脚本作为slim API的一部分。

首先我打开图表

with tf.Graph().as_default():

然后我使用slim API

设置数据集提供程序和预处理功能
  # Select the dataset
  # Create a dataset provider that loads data from the dataset
  # Select the preprocessing function
  ...
  image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocessing_threads,
      capacity=5 * batch_size)

然后我从GraphDef导入一个图表,并使用import_graph_def

将其加载到当前图表中
  quantized_graph_def = graph_pb2.GraphDef()
  with tf.gfile.FastGFile(path.join(cwd(), quantized_graph_filename), 'rb') as f:
    quantized_graph_def.ParseFromString(f.read())
  tf.import_graph_def(quantized_graph_def, input_map={'batch': images}, name='')

然后我设置指标并调用slim.evaluation.evaluate_once来处理批次

  # Define the metrics:
  names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
      'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
      'Recall_5': slim.metrics.streaming_recall_at_k(
          logits, labels, 5),
  })

  ...

  slim.evaluation.evaluate_once(
      master=master,
      checkpoint_path=checkpoint_path,
      logdir=log_dir,
      num_evals=num_batches,
      eval_op=list(names_to_updates.values()))

当我运行时,我收到以下错误:

Caused by op 'batch_1', defined at:
  File "vanilla_vgg.py", line 319, in <module>
    import_quantized_graph_with_imagenet()
  File "vanilla_vgg.py", line 251, in import_quantized_graph_with_imagenet
    tf.import_graph_def(quantized_graph_def, input_map={'batch': images}, name='')
  File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 553, in import_graph_def
    op_def=op_def)
  File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
    op_def=op_def)
  File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'batch_1' with dtype float and shape [100,224,224,3]
         [[Node: batch_1 = Placeholder[dtype=DT_FLOAT, shape=[100,224,224,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

我正在加载的GraphDef有一个名为batch的占位符操作符,其形状和dtype与张量images相同。作为参考,运行print(images)会返回:

Tensor("batch:0", shape=(100, 224, 224, 3), dtype=float32)

请注意,我已将input_map参数提供给import_graph_def函数,该函数应使用张量batch替换images占位符。我也尝试使用batch:0batch_1作为input_map的关键,但都不起作用。

根据tf.import_graph_def的文档:

  

input_map:将graph_def中的输入名称(作为字符串)映射到Tensor对象的字典。导入图中命名输入张量的值将重新映射到相应的Tensor值。

据我所知,input_map参数应该连接两个图,但这似乎并没有起作用。请参阅相关文章"Connecting Two Graphs Together using import_graph_def"。我相信我的做法和文章一样。

此外,evaluate_once是一个在单个函数调用中运行一批图像的函数,因此我不能简单地调用images.eval()并将结果传递给evaluate_once,因为它只会运行第一批。因此,两个图必须连接,并且能够通过单个调用运行。

0 个答案:

没有答案