如何使用golang将shape = [?]的输入字符串提供给tensorflow模型

时间:2017-09-23 09:32:00

标签: go tensorflow tensorflow-serving

火车模型的Python代码:

input_schema = dataset_schema.from_feature_spec({
    REVIEW_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.string),
    LABEL_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.int64)
})

在python预测中工作正常。客户端示例:

loaded_model = tf.saved_model.loader.load(sess, ["serve"], '/tmp/model/export/Servo/1506084916')
input_dict, output_dict =_signature_def_to_tensors(loaded_model.signature_def['default_input_alternative:None'])
start = datetime.datetime.now()
out = sess.run(output_dict, feed_dict={input_dict["inputs"]: ("I went and saw this movie last night",)})
print(out)
print("Time all: ", datetime.datetime.now() - start)

但golang客户端不起作用:

m, err := tf.LoadSavedModel("/tmp/model/export/Servo/1506084916", []string{"serve"}, &tf.SessionOptions{})
if err != nil {
    panic(fmt.Errorf("load model: %s", err))
}

data := "I went and saw this movie last night"
t, err := tf.NewTensor([]string{data})
if err != nil {
    panic(fmt.Errorf("tensor err: %s", err))
}
fmt.Printf("tensor: %v", t.Shape())

output, err = m.Session.Run(
    map[tf.Output]*tf.Tensor{
        m.Graph.Operation("save_1/StringJoin/inputs_1").Output(0): t,
    }, []tf.Output{
        m.Graph.Operation("linear/binary_logistic_head/predictions/classes").Output(0),
    }, nil,
)
if err != nil {
    panic(fmt.Errorf("run model: %s", err))
}

我收到了错误:

  

恐慌:运行模型:您必须为占位符张量提供值   '占位符'用dtype字符串和形状[?]            [[Node:Placeholder = Placeholder_output_shapes = [[?]],dtype = DT_STRING,shape = [?],   _device =" /作业:本地主机/复制:0 /任务:0 / CPU:0"]]

如何使用golang呈现shape=[?]张量?或者我需要更改python训练脚本的输入格式?

UPD:

运行此python代码后收到的这个字符串"save_1/StringJoin/inputs_1"

for n in sess.graph.as_graph_def().node:
    if "inputs" in n.name:
        print(n.name)

输出:

transform/transform/inputs/review/Placeholder 
transform/transform/inputs/review/Identity 
transform/transform/inputs/label/Placeholder 
transform/transform/inputs/label/Identity 
transform/transform_1/inputs/review/Placeholder 
transform/transform_1/inputs/review/Identity 
transform/transform_1/inputs/label/Placeholder 
transform/transform_1/inputs/label/Identity 
save_1/StringJoin/inputs_1 
save_2/StringJoin/inputs_1

2 个答案:

答案 0 :(得分:0)

错误告诉您You must feed a value for placeholder tensor 'Placeholder':这意味着在您为该占位符提供值之前,无法构建图表。

在你的python代码中,你可以在以下行中输入它:

input_dict["inputs"]: ("I went and saw this movie last night",)

事实上,input_dict["inputs"]评估为:<tf.Tensor 'Placeholder:0' shape=(?,) dtype=string>

在Go代码中,您正在寻找一个名为save_1/StringJoin/inputs_1的张量,它不是占位符。

遵循的规则是:在Python和Python中使用相同的输入。去。

要解决这个问题,您只需从图中提取名为Placeholder的占位符(就像在python中一样),然后使用它。

m.Graph.Operation("Placeholder").Output(0): t,

另外,我建议您在tensorflow API周围使用更完整且易于使用的包装器:tfgo

答案 1 :(得分:0)

还有一件事。我阅读了TF文档并找到了这个topic

有助于找到正确的输入/输出键,例如响应:

The given SavedModel SignatureDef contains the following input(s): inputs['inputs'] tensor_info:
    dtype: DT_STRING
    shape: (-1)
    name: Placeholder:0

PS /张贴作为焦点的答案