如何替换从文件加载的GraphDef中的占位符以将导入的图形连接到数据集提供程序?
此脚本大量借用eval_image_classifier.py
脚本作为slim
API的一部分。
首先我打开图表
with tf.Graph().as_default():
然后我使用slim
API
# Select the dataset
# Create a dataset provider that loads data from the dataset
# Select the preprocessing function
...
image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
images, labels = tf.train.batch(
[image, label],
batch_size=batch_size,
num_threads=num_preprocessing_threads,
capacity=5 * batch_size)
然后我从GraphDef导入一个图表,并使用import_graph_def
quantized_graph_def = graph_pb2.GraphDef()
with tf.gfile.FastGFile(path.join(cwd(), quantized_graph_filename), 'rb') as f:
quantized_graph_def.ParseFromString(f.read())
tf.import_graph_def(quantized_graph_def, input_map={'batch': images}, name='')
然后我设置指标并调用slim.evaluation.evaluate_once
来处理批次
# Define the metrics:
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
'Recall_5': slim.metrics.streaming_recall_at_k(
logits, labels, 5),
})
...
slim.evaluation.evaluate_once(
master=master,
checkpoint_path=checkpoint_path,
logdir=log_dir,
num_evals=num_batches,
eval_op=list(names_to_updates.values()))
当我运行时,我收到以下错误:
Caused by op 'batch_1', defined at:
File "vanilla_vgg.py", line 319, in <module>
import_quantized_graph_with_imagenet()
File "vanilla_vgg.py", line 251, in import_quantized_graph_with_imagenet
tf.import_graph_def(quantized_graph_def, input_map={'batch': images}, name='')
File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 553, in import_graph_def
op_def=op_def)
File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
op_def=op_def)
File "/localtmp/mp3t/venv/doggett/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'batch_1' with dtype float and shape [100,224,224,3]
[[Node: batch_1 = Placeholder[dtype=DT_FLOAT, shape=[100,224,224,3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
我正在加载的GraphDef有一个名为batch
的占位符操作符,其形状和dtype与张量images
相同。作为参考,运行print(images)
会返回:
Tensor("batch:0", shape=(100, 224, 224, 3), dtype=float32)
请注意,我已将input_map
参数提供给import_graph_def
函数,该函数应使用张量batch
替换images
占位符。我也尝试使用batch:0
和batch_1
作为input_map
的关键,但都不起作用。
根据tf.import_graph_def
的文档:
input_map
:将graph_def中的输入名称(作为字符串)映射到Tensor对象的字典。导入图中命名输入张量的值将重新映射到相应的Tensor值。
据我所知,input_map
参数应该连接两个图,但这似乎并没有起作用。请参阅相关文章"Connecting Two Graphs Together using import_graph_def
"。我相信我的做法和文章一样。
此外,evaluate_once
是一个在单个函数调用中运行一批图像的函数,因此我不能简单地调用images.eval()
并将结果传递给evaluate_once
,因为它只会运行第一批。因此,两个图必须连接,并且能够通过单个调用运行。