我已经构建了一个使用Dataset API训练的自动编码器。该张量架构描述了该体系结构:
我想在其他学习任务中仅重用编码器部分,所以我尝试使用
冻结图形g = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["AEC/encoded"])
g = tf.graph_util.extract_sub_graph(g, ["AEC/encoded"])
g = tf.graph_util.remove_training_nodes(g, protected_nodes=["AEC/input", "AEC/encoded"])
with open(str(Path(params.encoder_export_dir)/"encoder.pb"), "wb") as f:
f.write(g.SerializeToString())
然后尝试使用
在我的其他代码中导入它encoder_input = tf.placeholder(tf.float32, [None, 2049])
gd = tf.GraphDef()
with open('./path/to/encoder.pb', 'rb') as f:
gd.ParseFromString(f.read())
[out] = tf.import_graph_def(gd,
input_map={"AEC/input" : encoder_input},
return_elements=['AEC/encoded'],
name=''
)
但运行out
张量并在encoder_input
中提供内容时,我得到None
我试图在tensorboard中可视化导出的图形
似乎张量的形状消失了。
所以我的问题是如何以一种允许我在另一段代码中将其用作“黑匣子”的方式导出我的编码器?
编辑:
我使用占位符而不是数据集迭代器get_next张量来实现我的模型,除了输入节点(对应于占位符)将其形状存储在其属性中之外,缺少维度保持不变。
编辑2:
按照this issue report中的建议,我在使用
导出图表时添加了形状信息g = tf.get_default_graph().as_graph_def(add_shapes=True)
现在看到有关tensorboard架构的形状信息,但计算仍然返回None
答案 0 :(得分:0)
最后问题是由于语法错误:而不是使用
[out] = tf.import_graph_def(gd,
input_map={"AEC/input" : encoder_input},
return_elements=['AEC/encoded'],
name=''
)
AEC/encoded
中的return_elements
是一个操作,正确的方法是使用AEC/encoded:0