将Keras K.function()冻结为张量流图

时间:2017-07-04 10:33:41

标签: python tensorflow deep-learning keras

我有一个在Keras训练的模型,我使用K.function()使用其中间层输出。有没有办法将此K.function()对象保存为张量流图?我想在tensorflow服务中使用此对象,但我没有看到冻结Keras K.function()对象的方法

1 个答案:

答案 0 :(得分:0)

我最终想通了,我想我应该在这里发布,以防将来有人需要这个。我必须从K.function()对象获取节点名称。

encoder = K.function([...])  # define your Function object
encoder_tensor_names = [t.name for t in encoder.outputs]
encoder_node_names = [tn.replace(':0', '') for tn in encoder_tensor_names]  # node names are tensor names without :0
graph_def = tf.graph_util.convert_variables_to_constants(
    K.get_session(), 
    K.get_session().graph.as_graph_def(), 
    encoder_node_names 
)

然后按照https://stackoverflow.com/a/44044405/5453184中的代码写出二进制(或文本)文件。它的优点在于张量流足够智能,只能保留前馈操作所需的图形部分,在我的情况下,它也会冻结字嵌入,因此我不再需要单独的字矢量文件。