从元数据文件

时间:2017-02-17 20:06:35

标签: tensorflow

我已经训练了一个DCGAN模型,现在想把它加载到一个库中,通过图像空间优化可视化神经元激活的驱动因素。

以下代码有效,但在进行后续图像分析时迫使我使用(1,宽度,高度,通道)图像,这很痛苦(图书馆对网络输入形状的假设)。

# creating TensorFlow session and loading the model
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

new_saver = tf.train.import_meta_graph(model_fn)
new_saver.restore(sess, './')

我想更改input_map,在阅读完源代码后,我希望这段代码能够正常工作:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input})
new_saver.restore(sess, './')

但得到了一个错误:

  

ValueError:如果使用name,则tf.import_graph_def()需要非空input_map

当堆栈降至tf.import_graph_def()时,name字段设置为import_scope,因此我尝试了以下操作:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input}, import_scope='import')
new_saver.restore(sess, './')

让我获得了以下KeyError

  

KeyError:“名称'gradients / discriminator / minibatch / map / while / TensorArrayWrite / TensorArrayWriteV3_grad / TensorArrayReadV3 / RefEnter:0'指的是一个不存在的Tensor。操作',gradients / discriminator / minibatch / map /而/ TensorArrayWrite / TensorArrayWriteV3_grad / TensorArrayReadV3 / RefEnter',在图表中不存在。“

如果我设置'import_scope',无论是否设置'input_map',我都会得到相同的错误。

我不确定从哪里开始。

2 个答案:

答案 0 :(得分:2)

在较新版本的tensorflow> = 1.2.0中,以下步骤正常。

t_input = tf.placeholder(np.float32, shape=[None, width, height, channels], name='new_input') # define the input tensor

# here you need to give the name of the original model input placeholder name
# For example if the model has input as; input_original=  tf.placeholder(tf.float32, shape=(1, width, height, channels, name='original_placeholder_name'))
new_saver = tf.train.import_meta_graph(/path/to/checkpoint_file.meta, input_map={'original_placeholder_name:0':  t_input})
new_saver.restore(sess, '/path/to/checkpointfile')

答案 1 :(得分:0)

因此,主要问题是您没有正确使用语法。查看2的文档,了解tf.import_graph_deflink)的使用情况。

让我们分析这一行:

input_map

您没有概述new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input}, import_scope='import') 是什么,但它需要是文件的路径。 对于下一部分,在model_fn中,您说:替换 原始图(DCGAN)中input_map为{{{}的输入1}}我的变量(在当前图表中)称为name。有问题的是,imagest_input以不同的方式引用同一个对象:

t_input

换句话说,images中的 t_input = tf.placeholder(np.float32, name='images') 实际上应该是您尝试在DCGAN图中替换的变量名称。您必须以其基本形式导入图形(即,没有images行)并找出要链接到的变量的名称。导入图表后,它将出现在input_map返回的列表中。查找尺寸(1,宽度,高度,通道),但使用值代替变量名称。如果它是占位符,它看起来像input_map,其中tf.get_collection('variables')被替换为变量范围的任何内容。

提醒:

Tensorflow对于它所期望的图形看起来非常挑剔。因此,如果在原始图形规范中明确指定了宽度,高度和通道,那么当您尝试使用不同的维度集连接scope/Placeholder:0时,Tensorflow会抱怨(抛出错误)。而且,这是有道理的。如果系统是使用一组维度进行训练的,那么它只知道如何生成具有这些维度的图像。

理论上,你仍然可以在网络的前面贴上各种奇怪的东西。但是,您需要将其缩小以使其首先满足这些维度(并且Tensorflow文档表示最好使用图形外部的CPU执行此操作;即,在使用scope输入之前)。

希望有所帮助!