什么是Tensorflow Java Api`toGraphDef`在Python中的等价物?

时间:2017-04-05 22:46:21

标签: java scala tensorflow java-native-interface tensorflow-serving

我正在使用Tensorflow Java Api将已创建的Tensorflow模型加载到JVM中。 我以此为例:tensorflow/examples/LabelImage.java

这是我的简单scala代码:

import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))

如何保存模型以使Session和Graph存储在同一个文件中。如" PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"上方。

描述here它提到:

  

图表的序列化表示,通常称为a   GraphDef,可以由toGraphDef()和其他的等价物生成   语言API。

其他语言API的等价物是什么?我发现它不明显

注意:我已经在tensorflow_serving下查看了mnist_saved_model.py,但通过该过程保存它会给我一个.pb文件和一个variables文件夹。在尝试加载.pb文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef

1 个答案:

答案 0 :(得分:1)

目前使用tensorflow的Java API,我只发现了如何将图形保存为graphDef(即没有其变量和元数据)。这可以通过将Array [Byte]写入文件来完成:

-ssh

此处Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef) Graph class中的java对象。

我建议使用此处定义的SavedModel api从Python API保存模型。它会将模型保存在一个文件夹中,该文件夹包含.pb文件中的序列化图形和文件夹中的变量。请注意您使用的tag_constants,因为您需要在scala / java代码中使用变量加载模型。然后使用java api中的SavedModelBundle java类轻松加载带变量的图形和会话。它返回一个包含图形和包含变量值的会话的包装器:

myGraph

如果你已经尝试过这个,也许你可以分享你的代码,看看为什么它返回了一个无效的GraphDef。

另一种选择是冻结图形,即您将变量节点变为常量节点,因此.pb文件中的所有内容都是自包含的。 Mores infos here用于冻结部分