我正在尝试使用我在python中训练的Tensorflow模型对Scala中的数据进行评分(使用TF Java API)。对于模型,我使用了此regression example,唯一的变化是我从asText=True
删除了export_savedmodel
。
我的Scala片段:
val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
val s = b.session()
// output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
val input = "0.5,1,ax01,bx02"
val inputTensor = Tensor.create(input.getBytes("UTF-8"))
val result = s.runner()
.feed("csv_rows", inputTensor)
.fetch("dnn/logits/BiasAdd")
.run()
.get(0)
运行时,出现以下错误:
Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
[[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
我认为我准备输入Tensor的方式存在问题,但是我仍然坚持如何最好地调试它。
答案 0 :(得分:1)
错误消息表明某些操作中输入张量的形状不是预期的。
看看您链接到的Python笔记本(尤其是第8a和8c节),似乎输入张量应该是字符串张量的“批”,而不是单个字符串张量。
您可以通过比较Scala和Python程序中的张量的形状(scala中的inputTensor.shape()
与Python笔记本中提供给csv_rows
的{{1}}的形状)来观察到这一点
由此,看来predict_fn
想要成为字符串的向量,而不是单个标量字符串。为此,您需要执行以下操作:
inputTensor
希望有帮助