如何在golang中对文本执行DL-RNN模型?

时间:2017-11-06 17:58:22

标签: go tensorflow nlp rnn

我已根据reddit / twitter对话在RNN中构建tensor-flow模型。我将其保存在pb中。有谁知道如何通过golang中的模型传递原始文本字符串并生成输出?

modeldir := "/my_model.pb"

// Buffer input text
var buffer bytes.Buffer

args := os.Args[1:]

for _, arg := range args {
    buffer.WriteString(arg + " ")
}

inputText := buffer.String()

// Load the serialized GraphDef from a file.

model, err := ioutil.ReadFile(modeldir)
if err != nil {
    log.Fatal(err)
}
// Construct an in-memory graph from the serialized form.
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
    log.Fatal(err)
}
// Create a session for inference over graph.
session, err := tf.NewSession(graph, nil)
if err != nil {
    log.Fatal(err)
}
defer session.Close()

1 个答案:

答案 0 :(得分:1)

您可以使用tfgo轻松加载到Go并使用训练有素的张量流模型:只需使用tf.saved_model.builder.SavedModelBuilder导出训练后的模型,如tfgo自述文件所示。

但是,您只需从图表中提取输入占位符,然后使用它来提供网络。

假设您导出模型调用它my_model并使用标记tag对其进行标记。另外,我们假设您的输入占位符名为" Placeholder"。此外,您必须知道输出节点的名称。我们称之为output/node/path/op。然后您的代码应如下所示:

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/tensorflow/tensorflow/tensorflow/go"
        "flags"
)

func main() {
        model := tg.LoadModel("my_model", []string{"tag"}, nil)

        // Buffer input text
        var buffer bytes.Buffer
        args := os.Args[1:]

        for _, arg := range args {
            buffer.WriteString(arg + " ")
        }
        // handle the retunred error below, if any
        inputText, _ := tf.NewTensor(buffer.String())

        results := model.Exec([]tf.Output{
                model.Op("output/node/path/op", 0),
        }, map[tf.Output]*tf.Tensor{
                model.Op("Placeholder", 0): inputText,
        })
        // do something with results[0]
}