Tomcat中的相同tensorflow模型推断得到了与简单java应用程序

时间:2017-09-27 10:18:32

标签: java python tomcat tensorflow

我们正在 Tomcat7(java1.8)中部署tensorflow模型(seq2seq问题回答),而在调试时,我们只是使用简单的java Application(public static void main()函数)来测试模型推断结果。简单java应用程序中的推理结果与python原始版本中的相同。 但是当我们在Tomcat中启动整个包(WAR)时,它会得到完全不同的结果,而推理代码/测试输入语句/模型文件都是相同的。

有人能给我们一些关于这个问题的提示吗?

  1. 简单的java应用程序(public static void main()函数)获得与python tensorflow版本推断结果相同的结果。我们将它们视为正确的。
  2. Tomcat加载的模型获得不同的结果。结果看起来像普通句子,但在考虑问题时,答案意义非常糟糕。
  3. 模型文件(protobuf)/ java代码/测试输入句子在上述两种情况下是相同的。
  4. 推断概率为1.0f。
  5. 模型加载功能:

    @Override
    public boolean reload(String modelURL) {
        logger.info("tensorflow version:{}", TensorFlow.version());
        try {
            logger.info("start to download model path:{}", modelURL);
            //TODO: download model
            logger.info("start to load model path:{} tag:{}", MODEL_PATH, MODEL_TAG);
            bundle = SavedModelBundle.load(MODEL_PATH, MODEL_TAG);
            session = bundle.session();
            logger.info("finish loading model!");
    
        } catch(Exception e) {
            logger.error("reload model exception:", e);
            return false;
        }
    
        return true;
    }
    

    推理代码:

        @Override
    public String predict(String query, String candidateAnswer) {
        if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
            logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
            return null;
        }
        String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
        String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);
    
        try(Tensor queryTensor = Tensor.create(queryPad.getBytes());
            Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
            Tensor candidateTensor = Tensor.create(candidatePad.getBytes());
            Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
        {
            List<Tensor> result = session.runner()
                    .feed("source_tokens", queryTensor)
                    .feed("source_len", queryLenTensor)
                    .feed("source_candidate_tokens", candidateTensor)
                    .feed("source_candidate_len", candidateLenTensor)
                    .fetch("model/att_seq2seq/predicted_tokens_scalar")
                    .run();
    
            Tensor predictedTensor = result.get(0);
            String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
            logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
            return predictedTokens;
        } catch (Exception e) {
            logger.error("exception:", e);
        }
    
        return null;
    }
    

1 个答案:

答案 0 :(得分:0)

是的,这是编码问题。当我们在简单的java应用程序(public static void main())中启动模型时,它的默认编码是UTF-8,同时调用getBytes()。但是当我们在tomcat中启动模型时,其编码方案是ISO-8859-1。

Tensor queryTensor = Tensor.create( queryPad.getBytes(&#34; UTF-8&#34;)

Tensor candidateTensor = Tensor.create( candidatePad.getBytes(&#34; UTF-8&#34;)

    @Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);

    try(Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
        Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
        Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
        Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
    {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();

        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }

    return null;
}