我现在正在使用tensorflow(python)来训练我的模型,并希望在线使用tensorflow(java)来推断结果。
计算图有一个返回形状[1,16] 结果的操作,张量中的每个元素都是一个字符串。现在我想将结果转换为整个字符串。
我创建一个ByteBuffer,并调用Tensor.writeTo在缓冲区中写入数据。但是当我解码最终缓冲区时,它在标题中有一些意想不到的字符,我猜最后的字节可能包含一些张量元信息。
Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");
结果是这样的:第一部分是一些不正确的代码,最后一部分是正确的。
& * ? * C J M X & * ? * C J M X hello,world!
我想也许带有形状的Tensor(1,16)在字节中有元信息,但我不知道如何获取我需要的部分。
有没有人知道如何在java tensorflow接口中将多维张量转换为java字符串?
答案 0 :(得分:0)
我为此找到了一个workaroud! 训练模型时,我在张量上调用 tf.reduce_join 形状(1,16)得到一个标量。 当在java语言中进行推理时,我只需获取该标量节点,并调用 tensor.byteValue()来获取张量字节。如果没有标题代码,它将返回纯粹的结果。
答案 1 :(得分:0)
如果操作的结果具有形状[1, 16]
,则表示它生成16个不同的字符串,而不是一个字符串。
最近才添加了对Java中多维字符串张量的支持(github commit),并且不包含在TensorFlow 1.3及更早版本的预构建二进制文件中。您必须从源代码构建或等待TensorFlow 1.4版本。
使用该功能,您应该可以使用以下内容解码(1, 16)
形状张量:
Tensor predictedTensor = result.get(0);
byte[][][] predictedTokenBytes = predictedTensor.copyTo(new byte[1][16][]);
String[] predictedTokens = new String[16];
for (int i = 0; i < 16; ++i) {
// This works under the assumption that the model is actually
// producing UTF-8 strings
predictedTokens[i] = new String(predictedTokenBytes[0][i], "UTF-8");
}
如果你真的需要一个字符串,那么你可以使用tf.reduce_join
让模型将16个字符串合并为一个,然后提取标量。