我正在使用Tensorflow Java Api将已创建的Tensorflow模型加载到JVM中。 我以此为例:tensorflow/examples/LabelImage.java
这是我的简单scala代码:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存模型以使Session和Graph存储在同一个文件中。如" PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"上方。
描述here它提到:
图表的序列化表示,通常称为a GraphDef,可以由toGraphDef()和其他的等价物生成 语言API。
其他语言API的等价物是什么?我发现它不明显
注意:我已经在tensorflow_serving下查看了mnist_saved_model.py,但通过该过程保存它会给我一个.pb
文件和一个variables
文件夹。在尝试加载.pb
文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef
答案 0 :(得分:1)
目前使用tensorflow的Java API,我只发现了如何将图形保存为graphDef(即没有其变量和元数据)。这可以通过将Array [Byte]写入文件来完成:
-ssh
此处Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
是Graph class中的java对象。
我建议使用此处定义的SavedModel api从Python API保存模型。它会将模型保存在一个文件夹中,该文件夹包含.pb文件中的序列化图形和文件夹中的变量。请注意您使用的tag_constants,因为您需要在scala / java代码中使用变量加载模型。然后使用java api中的SavedModelBundle java类轻松加载带变量的图形和会话。它返回一个包含图形和包含变量值的会话的包装器:
myGraph
如果你已经尝试过这个,也许你可以分享你的代码,看看为什么它返回了一个无效的GraphDef。
另一种选择是冻结图形,即您将变量节点变为常量节点,因此.pb文件中的所有内容都是自包含的。 Mores infos here用于冻结部分