我正在使用Java服务通过Python学习的Tensorflow模型。该模型有两个输入。代码如下:
def predict(float32InputShape: (Long, Long),
float32Inputs: Seq[Seq[Float]],
uint8InputShape: (Long, Long),
uint8Inputs: Seq[Seq[Byte]]
): Array[Float] = {
val float32Input = Tensor.create(
Array(float32InputShape._1, float32InputShape._2),
FloatBuffer.wrap(float32Inputs.flatten.toArray)
)
val uint8Input = Tensor.create(
classOf[UInt8],
Array(uint8InputShape._1, uint8InputShape._2),
ByteBuffer.wrap(uint8Inputs.flatten.toArray)
)
val tfResult = session
.runner()
.feed("serving_default_float32_Input", float32Input)
.feed("serving_default_uint8_Input", uint8Input)
.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])
tfResult
}
我想做的是通过传递输入,例如Python中的feed_dict,来重构该方法,使其更通用。就是这样:
def predict2(inputs: Map[String, Seq[Seq[Float]]]): Array[Float] = {
...
session
.runner()
.feed(inputs)
...
}
inputs
映射的键将是输入层的名称。除非我创建一个宏(我要避免),否则无法使用feed
方法执行此操作。
是否可以使用Tensorflow的Java API(我使用的是TF 2.0)来做到这一点?
修改: 我找到了解决方案(由于@geometrikal的回答),代码在Scala中,但在Java中要做到同样难。
val runnerWithInputLayers = inputs.foldLeft(session.runner()) {
case (sess, (layerName, array)) =>
val tensor = createTensor(array)
sess.feed(layerName, tensor)
}
val output = runnerWithInputLayers
.fetch(outputLayer)
.run()
.get(0)
.expect(Float.getClass)
可能是因为.feed
方法返回了Session.Runner
并且提供了输入层。
答案 0 :(得分:0)
您可以循环输入每个。如果不太熟悉Java脚本,但伪代码就像
例如
val tfResult = session.runner()
for(key, value : inputs) {
tfResult = tfResult(key, value)
}
tfResult = tfResult.fetch("PartitionedCall")
.run()
.get(0)
.expect(classOf[java.lang.Float])
请记住,您可以随时断开功能链,例如result = foo.bar().baz().qux()
可以写成temp = foo.bar().baz(); result = temp.qux()