我已根据reddit / twitter对话在RNN
中构建tensor-flow
模型。我将其保存在pb
中。有谁知道如何通过golang
中的模型传递原始文本字符串并生成输出?
modeldir := "/my_model.pb"
// Buffer input text
var buffer bytes.Buffer
args := os.Args[1:]
for _, arg := range args {
buffer.WriteString(arg + " ")
}
inputText := buffer.String()
// Load the serialized GraphDef from a file.
model, err := ioutil.ReadFile(modeldir)
if err != nil {
log.Fatal(err)
}
// Construct an in-memory graph from the serialized form.
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
log.Fatal(err)
}
// Create a session for inference over graph.
session, err := tf.NewSession(graph, nil)
if err != nil {
log.Fatal(err)
}
defer session.Close()
答案 0 :(得分:1)
您可以使用tfgo轻松加载到Go并使用训练有素的张量流模型:只需使用tf.saved_model.builder.SavedModelBuilder
导出训练后的模型,如tfgo
自述文件所示。
但是,您只需从图表中提取输入占位符,然后使用它来提供网络。
假设您导出模型调用它my_model
并使用标记tag
对其进行标记。另外,我们假设您的输入占位符名为" Placeholder"。此外,您必须知道输出节点的名称。我们称之为output/node/path/op
。然后您的代码应如下所示:
import (
"fmt"
tg "github.com/galeone/tfgo"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"flags"
)
func main() {
model := tg.LoadModel("my_model", []string{"tag"}, nil)
// Buffer input text
var buffer bytes.Buffer
args := os.Args[1:]
for _, arg := range args {
buffer.WriteString(arg + " ")
}
// handle the retunred error below, if any
inputText, _ := tf.NewTensor(buffer.String())
results := model.Exec([]tf.Output{
model.Op("output/node/path/op", 0),
}, map[tf.Output]*tf.Tensor{
model.Op("Placeholder", 0): inputText,
})
// do something with results[0]
}