将数据传递给Java中的Tensorflow模型

时间:2018-07-19 18:15:35

标签: java tensorflow

我正在尝试使用我在python中训练的Tensorflow模型对Scala中的数据进行评分(使用TF Java API)。对于模型,我使用了此regression example,唯一的变化是我从asText=True删除了export_savedmodel

我的Scala片段:

  val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
  val s = b.session()

  // output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
  val input = "0.5,1,ax01,bx02"

  val inputTensor = Tensor.create(input.getBytes("UTF-8"))

  val result = s.runner()
    .feed("csv_rows", inputTensor)
    .fetch("dnn/logits/BiasAdd")
    .run()
    .get(0)

运行时,出现以下错误:

Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
 [[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)

我认为我准备输入Tensor的方式存在问题,但是我仍然坚持如何最好地调试它。

1 个答案:

答案 0 :(得分:1)

错误消息表明某些操作中输入张量的形状不是预期的。

看看您链接到的Python笔记本(尤其是第8a和8c节),似乎输入张量应该是字符串张量的“批”,而不是单个字符串张量。

您可以通过比较Scala和Python程序中的张量的形状(scala中的inputTensor.shape()与Python笔记本中提供给csv_rows的{​​{1}}的形状)来观察到这一点

由此,看来predict_fn想要成为字符串的向量,而不是单个标量字符串。为此,您需要执行以下操作:

inputTensor

希望有帮助