We are deploying a TensorFlow model (seq2seq question answering) in Tomcat 7 on Java 1.8. During debugging we first tested model inference from a simple Java application (a plain public static void main() entry point), and the inference results match the original Python version. However, when we deploy the whole package (WAR) in Tomcat, we get completely different results, even though the inference code, the test input sentences, and the model files are all identical.
Can anyone give us a hint about this problem?
Model loading function:
@Override
public boolean reload(String modelURL) {
    logger.info("tensorflow version:{}", TensorFlow.version());
    try {
        logger.info("start to download model path:{}", modelURL);
        // TODO: download model
        logger.info("start to load model path:{} tag:{}", MODEL_PATH, MODEL_TAG);
        bundle = SavedModelBundle.load(MODEL_PATH, MODEL_TAG);
        session = bundle.session();
        logger.info("finish loading model!");
    } catch (Exception e) {
        logger.error("reload model exception:", e);
        return false;
    }
    return true;
}
Inference code:
@Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);
    try (Tensor queryTensor = Tensor.create(queryPad.getBytes());
         Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
         Tensor candidateTensor = Tensor.create(candidatePad.getBytes());
         Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN)) {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();
        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }
    return null;
}
Answer (score: 0)
Yes, it was an encoding problem. When we run the model from the simple Java application (public static void main()), the JVM's default charset is UTF-8, so calling getBytes() produces UTF-8 bytes. But when we run the same model inside Tomcat, the default charset is ISO-8859-1, so the same getBytes() call silently encodes the input differently.
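To confirm which charset each environment actually uses before changing any code, a minimal sketch like the one below can be run in both the standalone application and the web app (in the WAR, log the same two values from any startup hook); the class name CharsetCheck is only for illustration and is not part of the original project:

import java.nio.charset.Charset;

public class CharsetCheck {
    public static void main(String[] args) {
        // Expected to print UTF-8 when run as a plain Java application,
        // and ISO-8859-1 under a Tomcat JVM whose file.encoding differs.
        System.out.println("default charset: " + Charset.defaultCharset());
        System.out.println("file.encoding:   " + System.getProperty("file.encoding"));
    }
}

Once the mismatch is confirmed, the fix is to stop relying on the default charset and pass it explicitly when calling getBytes():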
Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
@Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);
    try (Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
         Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
         Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
         Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN)) {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();
        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }
    return null;
}
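An alternative is to force the default charset at the container level, for example by adding -Dfile.encoding=UTF-8 to Tomcat's JVM options (typically via CATALINA_OPTS in setenv.sh). Passing the charset explicitly in the code, as above, keeps the inference independent of container configuration, and it mirrors the decoding side, which already specifies the charset in new String(predictedTensor.bytesValue(), "UTF-8").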