在Golang上进行Tensorflow部分运行(RNN状态)

时间:2019-05-21 08:51:35

标签: tensorflow go deep-learning recurrent-neural-network

我有一个GRU RNN文本生成模型,我将其作为protobuf导入了Golang。

model, err := tf.LoadSavedModel("poetryModel", []string{"goTag"}, nil)

类似于the code from this Tensorflow tutorial,我正在运行一个预测循环:

for len(generated_text) < 1000 {
    result, err := model.Session.Run(
            map[tf.Output]*tf.Tensor{
                model.Graph.Operation("inputLayer_input").Output(0): tensor,
            },
            []tf.Output{
                model.Graph.Operation("outputLayer/add").Output(0),
            },
            nil,
        )
    ...}

但是,此实现在每次循环后都丢弃所有中间状态,这会导致生成错误的文本。 我尝试使用Partial Run,但在第二次运行时却抛出错误:

pr, err := model.Session.NewPartialRun(
    []tf.Output{ model.Graph.Operation("inputLayer_input").Output(0), },
    []tf.Output{ model.Graph.Operation("outputLayer/add").Output(0), },
    []*tf.Operation{ model.Graph.Operation("outputLayer/add") },
)
if err != nil {
    panic(err)
}

...

result, err := pr.Run(
        map[tf.Output]*tf.Tensor{
            model.Graph.Operation("inputLayer_input").Output(0): tensor,
        },
        []tf.Output{
            model.Graph.Operation("outputLayer/add").Output(0),
        },
        nil,
    )

Error running the session with input, err: Must run 'setup' before performing partial runs!

This question与此类似,但在Python中。另外,Go中没有关于设置功能的文档。 我是直接使用TF计算图和Golang的新手,所以能提供任何帮助。

0 个答案:

没有答案