我正在测试在Keras中构建模型,然后将其转换为某种Tensorflow格式,以便我可以在Tensorflow C ++ API中运行预测。我正在适应这个tutorial。我正在测试MNIST数据集,并在Keras中构建了我的模型:
inpt = keras.layers.Input(shape = (28,28,1), name = "input_node")
x = keras.layers.Convolution2D(16, 2, padding = 'same', activation = 'relu')(inpt)
x = keras.layers.MaxPool2D(pool_size = 2)(x)
x = keras.layers.Convolution2D(32, 2, padding = 'same', activation = 'relu')(x)
x = keras.layers.MaxPool2D(pool_size = 2)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(128, activation = 'relu')(x)
output = keras.layers.Dense(10, activation = 'softmax', name = "output_node")(x)
model = keras.models.Model(inpt,output)
model.compile(optimizer = keras.optimizers.Adam(lr = 0.0001), loss = 'categorical_crossentropy', metrics = ['accuracy'])
然后使用model_to_estimator
函数:
estimator_model = tf.keras.estimator.model_to_estimator(keras_model = model, model_dir = './TF_MNIST')
效果很好,我可以使用以下方式进行训练:
estimator_model.train(input_fn = input_function(X_train,y_train,True))
但是,我尝试使用freeze_graph如下:
checkpoint_state_name = "model.ckpt-21001.index"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
input_graph_path = os.path.join('./TF_MNIST', input_graph_name)
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = os.path.join('./TF_MNIST', checkpoint_state_name)
output_node_names = "output_node"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join('./TF_MNIST', output_graph_name)
clear_devices = False
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, input_checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, initializer_nodes = "input_node")
我选择了output_graph.pb
作为生成的freeze_graph的目的地。
我收到以下错误:
ValueError Traceback (most recent call last)
<ipython-input-69-215edbaaf017> in <module>()
3 output_node_names, restore_op_name,
4 filename_tensor_name, output_graph_path,
----> 5 clear_devices, initializer_nodes = "input_node")
ValueError: No variables to save
在本教程的示例中,没有输入参数initializer_nodes
所以我假设它是输入节点的名称。此外,当我使用不是.index
文件的检查点文件时,它会发出Data loss
警告,说明数据格式不正确。
.index
检查点文件是正确的(如果确实是正确的话)?input_graph.pb
图表,而我的是.pbtxt
,为什么会这样?tf.Session()
引入我的Keras模型以存储和打印准确度分数,因为目前这些分数不是在训练中打印的,也不是存储在TensorBoard正在读取的检查点文件中。非常感谢任何这些问题的帮助。