tensorflow java模型推理将fetched tensor转换为string?

时间:2017-09-23 07:30:57

标签: java python string tensorflow tensor

我现在正在使用tensorflow(python)来训练我的模型,并希望在线使用tensorflow(java)来推断结果。

计算图有一个返回形状[1,16] 结果的操作,张量中的每个元素都是一个字符串。现在我想将结果转换为整个字符串。

我创建一个ByteBuffer,并调用Tensor.writeTo在缓冲区中写入数据。但是当我解码最终缓冲区时,它在标题中有一些意想不到的字符,我猜最后的字节可能包含一些张量元信息。

Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");

结果是这样的:第一部分是一些不正确的代码,最后一部分是正确的。

& *  ? *  C J M X & *  ? *  C J M X hello,world!

我想也许带有形状的Tensor(1,16)在字节中有元信息,但我不知道如何获取我需要的部分。

有没有人知道如何在java tensorflow接口中将多维张量转换为java字符串?

2 个答案:

答案 0 :(得分:0)

我为此找到了一个workaroud! 训练模型时,我在张量上调用 tf.reduce_join 形状(1,16)得到一个标量。 当在java语言中进行推理时,我只需获取该标量节点,并调用 tensor.byteValue()来获取张量字节。如果没有标题代码,它将返回纯粹的结果。

答案 1 :(得分:0)

如果操作的结果具有形状[1, 16],则表示它生成16个不同的字符串,而不是一个字符串。

最近才添加了对Java中多维字符串张量的支持(github commit),并且不包含在TensorFlow 1.3及更早版本的预构建二进制文件中。您必须从源代码构建或等待TensorFlow 1.4版本。

使用该功能,您应该可以使用以下内容解码(1, 16)形状张量:

Tensor predictedTensor = result.get(0);
byte[][][] predictedTokenBytes = predictedTensor.copyTo(new byte[1][16][]);
String[] predictedTokens = new String[16];
for (int i = 0; i < 16; ++i) {
  // This works under the assumption that the model is actually
  // producing UTF-8 strings    
  predictedTokens[i] = new String(predictedTokenBytes[0][i], "UTF-8");
}

如果你真的需要一个字符串,那么你可以使用tf.reduce_join让模型将16个字符串合并为一个,然后提取标量。

相关问题