使用带有Keras模型的freeze_graph.py转换为估算器时出错

时间:2018-06-06 11:39:11

标签: python tensorflow keras

我正在测试在Keras中构建模型,然后将其转换为某种Tensorflow格式,以便我可以在Tensorflow C ++ API中运行预测。我正在适应这个tutorial。我正在测试MNIST数据集,并在Keras中构建了我的模型:

inpt = keras.layers.Input(shape = (28,28,1), name = "input_node")
x = keras.layers.Convolution2D(16, 2, padding = 'same', activation = 'relu')(inpt)
x = keras.layers.MaxPool2D(pool_size = 2)(x)
x = keras.layers.Convolution2D(32, 2, padding = 'same', activation = 'relu')(x)
x = keras.layers.MaxPool2D(pool_size = 2)(x)

x = keras.layers.Flatten()(x)

x = keras.layers.Dense(128, activation = 'relu')(x)

output = keras.layers.Dense(10, activation  = 'softmax', name = "output_node")(x)

model = keras.models.Model(inpt,output)

model.compile(optimizer = keras.optimizers.Adam(lr = 0.0001), loss = 'categorical_crossentropy', metrics = ['accuracy'])

然后使用model_to_estimator函数:

estimator_model = tf.keras.estimator.model_to_estimator(keras_model = model, model_dir = './TF_MNIST')

效果很好,我可以使用以下方式进行训练:

estimator_model.train(input_fn = input_function(X_train,y_train,True))

但是,我尝试使用freeze_graph如下:

checkpoint_state_name = "model.ckpt-21001.index"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"

input_graph_path = os.path.join('./TF_MNIST', input_graph_name)
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = os.path.join('./TF_MNIST', checkpoint_state_name)

output_node_names = "output_node" 
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join('./TF_MNIST', output_graph_name)
clear_devices = False

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                      input_binary, input_checkpoint_path,
                      output_node_names, restore_op_name,
                      filename_tensor_name, output_graph_path,
                      clear_devices, initializer_nodes = "input_node")

我选择了output_graph.pb作为生成的freeze_graph的目的地。

我收到以下错误:

ValueError  Traceback (most recent call last)
<ipython-input-69-215edbaaf017> in <module>()
      3  output_node_names, restore_op_name,
      4  filename_tensor_name, output_graph_path,
----> 5  clear_devices, initializer_nodes = "input_node")

ValueError: No variables to save

在本教程的示例中,没有输入参数initializer_nodes所以我假设它是输入节点的名称。此外,当我使用不是.index文件的检查点文件时,它会发出Data loss警告,说明数据格式不正确。

问题:

  1. 如何解决此错误?
  2. 为什么.index检查点文件是正确的(如果确实是正确的话)?
  3. 本教程有一个input_graph.pb图表,而我的是.pbtxt,为什么会这样?
  4. 我可以将tf.Session()引入我的Keras模型以存储和打印准确度分数,因为目前这些分数不是在训练中打印的,也不是存储在TensorBoard正在读取的检查点文件中。
  5. 非常感谢任何这些问题的帮助。

0 个答案:

没有答案