vgg16 keras保存导出c ++的模型

时间:2018-05-17 16:51:18

标签: python c++ keras deep-learning

我最近在处理vgg16以重用其模型。

蟒蛇中的

:( Keras)

model = applications.VGG16(include_top=False, weights='imagenet')

一切都很好。

我需要使用compile和fit导出THiS模型来导出c ++ json文件。 如何正确导出ThiS模型h5文件以用于导出模型?

1 个答案:

答案 0 :(得分:2)

一种方法是将您的Keras模型转换为Python中的TensorFlow模型,然后将冻结的图形导出到.pb文件。然后在C ++中加载它。我使用这段代码从Keras导出冻结的.pb文件。

import tensorflow as tf

from keras import backend as K
from tensorflow.python.framework import graph_util

K.set_learning_phase(0)
model = function_that_returns_your_keras_model()
sess = K.get_session()

output_node_name = "my_output_node" # Name of your output node

with sess as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    graph_def = sess.graph.as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
                                                                 sess,
                                                                 sess.graph.as_graph_def(),
                                                                 output_node_name.split(","))
    tf.train.write_graph(output_graph_def,
                         logdir="my_dir",
                         name="my_model.pb",
                         as_text=False)

然后,您可以按照任何有关如何在C ++中加载.pb文件的教程进行操作。举个例子:https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f

Keras在TensorFlow图中注入learning_phase变量,也可能注入其他仅Keras变量 - 如果我没记错的话,你应该确保从图中删除它们。