我最近开始在Python中使用Tensorflow建模深度学习算法。我希望能够使用Tensorflow Java API在Scala中使用SavedModel。 但是,这是我倾向于将其集成到代码中时遇到的错误:
Program result = Failure(java.lang.IllegalArgumentException: Input to reshape is a tensor with 4 values, but the requested shape has 1
[[{{node graph/Reshape}} = Reshape[T=DT_BOOL, Tshape=DT_INT32, _output_shapes=[[?,?,1]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](graph/SequenceMask/Less, graph/Reshape/shape)]])
这也是我的Scala代码打印的模型元数据:
model metadata: ModelMetadata(Map(serving_default -> SignatureMetadata(tensorflow/serving/predict,Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1))),Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1))))))
serving signature: SignatureMetadata(tensorflow/serving/predict,Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1))),Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1))))
serving signature inputs: Map(chars -> TensorMetadata(chars:0,chars,DTypeString,List(-1, -1, -1)), words -> TensorMetadata(words:0,words,DTypeString,List(-1, -1)), nchars -> TensorMetadata(nchars:0,nchars,DTypeInt32,List(-1, -1)), nwords -> TensorMetadata(nwords:0,nwords,DTypeInt32,List(-1)))
serving signature outputs: Map(pred_ids_ema -> TensorMetadata(cond_11/Merge:0,cond_11/Merge,DTypeInt32,List(-1, -1)), pred_ids -> TensorMetadata(cond/Merge:0,cond/Merge,DTypeInt32,List(-1, -1)), tags_ema -> TensorMetadata(index_to_string_Lookup_1:0,index_to_string_Lookup_1,DTypeString,List(-1, -1)), tags -> TensorMetadata(index_to_string_Lookup:0,index_to_string_Lookup,DTypeString,List(-1, -1)))
以下是用于馈送我的模型和运行会话的代码(请注意,processFeatures尚未完全编码,缺少一些动态性):
def processFeatures(line: String): (Array[Array[Array[Byte]]], Array[Int], Array[Array[Array[Byte]]], Array[Array[Int]]) = {
val nbWords = line.split(" ").length
val maxNbChars = line.split(" ").map(_.length).foldLeft(0) { (acc, current) =>
if (current < acc) acc
else if (current > acc) current
else acc
}
val words = Array.ofDim[Array[Byte]](1, nbWords)
words(0)(0) = line.getBytes("UTF-8")
logger.info(s"Number of words: $nbWords | Max number of characters: $maxNbChars")
val nChars = Array.ofDim[Int](1, nbWords)
nChars(0)(0) = line.split(" ")(0).length
val chars = Array.ofDim[Byte](1, nbWords, maxNbChars)
chars(0)(0) = line.getBytes("UTF-8")
(words, new Array[Int](nbWords), chars, nChars)
}
val features = processFeatures("toto")
println(Tensor.create(features._1))
println(Tensor.create(features._2))
println(Tensor.create(features._3))
println(Tensor.create(features._4))
val outputs = model.bundle.session.runner
.feed(signature.inputs("words").opName, Tensor.create(features._1))
.feed(signature.inputs("nwords").opName, Tensor.create(features._2))
.feed(signature.inputs("chars").opName, Tensor.create(features._3))
.feed(signature.inputs("nchars").opName, Tensor.create(features._4))
.fetch(output.opName)
.run()
outputs
.asScala
还可以向您展示打印出的张量特性:
STRING tensor with shape [1, 1]
INT32 tensor with shape [1]
STRING tensor with shape [1, 1]
INT32 tensor with shape [1, 1]
非常感谢您的帮助。
致谢