使用keras训练TF模型,然后在Go中进行评估

时间:2017-09-22 14:20:15

标签: python go tensorflow keras

我正在尝试使用keras设置经典的MNIST质询模型,然后保存tensorflow图并随后将其加载到Go中,并使用一些输入进行评估。我一直关注this articlegithub提供完整代码Formula。 Nils只使用tensorflow来设置comp.graph,但我想使用keras。我管理以与他一样的方式保存模型

模型:

   model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=(28,28,1), name="inputNode"))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax', name="inferNode"))

运行正常,训练和评估,然后保存如上所述:

builder = tf.saved_model.builder.SavedModelBuilder("mnistmodel_my")
# GOLANG note that we must tag our model so that we can retrieve it at inference-time
builder.add_meta_graph_and_variables(sess, ["serve"])
builder.save()

然后我尝试评估为:

result, runErr := model.Session.Run(
        map[tf.Output]*tf.Tensor{
            model.Graph.Operation("inputNode").Output(0): tensor,
        },
        []tf.Output{
            model.Graph.Operation("inferNode").Output(0),
        },
        nil,
    )

在Go中,我按照示例进行操作,但在评估时,我得到:

    panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.

goroutine 1 [running]:
github.com/tensorflow/tensorflow/tensorflow/go.Output.c(0x0, 0x0, 0x0, 0x0)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:119 +0xbb
github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc4200723c8)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:307 +0x22d
github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc420078060, 0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
    /Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:85 +0x153
main.main()
    /Users/air/PycharmProjects/GoTensor/custom.go:36 +0x341
exit status 2

因为它说nil-Operation我认为我可能错误地标记了节点。但我不知道我应该标记哪些其他节点?

非常感谢!!!

2 个答案:

答案 0 :(得分:4)

您的代码应该可以正常运行。你对零操作的原因是正确的。

您只需要找到“inputNode”的完整节点名称。

从python中,在模型定义之后,您可以循环遍历图形节点并以这种方式查找完整名称:

for n in sess.graph.as_graph_def().node:
    if "inputNode" in n.name:
        print(n.name)

获得完整名称后,您可以在Go程序中使用它。

另外,我建议您在tensorflow API周围使用更完整且易于使用的包装器:tfgo

答案 1 :(得分:0)

以显示session.graph列表中的所有项目(在Golang中):

ops := model.Graph.Operations()
for _, op := range ops {
    fmt.Println(op.Name())
}