在Tensorflow for Java中加载预训练的模型

时间:2018-07-19 17:01:36

标签: java tensorflow pre-trained-model

我正在尝试使用Java API在Tensorflow中加载pre-trained models

我注意到,随着时间的流逝,已保存的模型文件的格式已更改,现在已保存的模型的文件格式为.pb .ckpt,模型目录的文件名为model.ckpt.data-00000-of-00001 , model.ckpt.index

我正在遵循读取LabelImage example中指定的模型的方法。但是在此示例中,文件格式为protobuf .pb。我看到最新保存的模型以.ckptmodel.ckpt.data-00000-of-00001 , model.ckpt.index格式保存。

我尝试对包含文件export_dirmodel.ckpt.data-00000-of-00001的{​​{1}}使用SavedModelBundle方法,但出现此错误

model.ckpt.index

`2018-07-18 16:54:00.388790: I tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { }; Status: fail. Took 95 microseconds.

有人可以告诉我我做错了什么吗,或者让我知道如何读取以Java中Exception in thread "main" org.tensorflow.TensorFlowException: SavedModel not found in export directory: /path/to/model_dir at org.tensorflow.SavedModelBundle.load(Native Method) at org.tensorflow.SavedModelBundle.load(SavedModelBundle.java:39)以外的文件格式保存的已保存模型。

1 个答案:

答案 0 :(得分:0)

我认为您可以尝试通过两种方式解决问题:

  1. 将保存的模型(检查点文件)的格式转换为protobuf文件

将保存的模型还原到当前会话后:sess,

# Freeze the graph, with output _node_names is the name of the output when construct the model  
# Eg. output_node_names = ["prediction"]
frozen_graph_def = tf.graph_util.convert_variables_to_constants (sess, sess.graph_def, output_node_names)

# Save the frozen graph  
with open (frozen_graph_file, "wb") as f:    
   f.write(frozen_graph_def.SerializeToString()

它应该将以前的格式转换为新的格式。

  1. 重新训练模型并将其保存为.pb格式。