Tflite预测与冻结推理图预测完全不同

时间:2019-08-10 12:59:56

标签: android machine-learning tensorflow-lite

我从事眼部区域定位项目,并训练了自己的自定义数据集以使用Tensorflow库创建模型。

我生成.ckpts文件(模型),得到可接受的结果,将此模型转换为.pb冻结推断模型,并在网络摄像头上测试了冻结模型的准确性,并且工作正常。

问题是当我将.pb模型转换为tflite模型时。 使用Android应用程序和MLkit firebase自定义模型时,结果非常糟糕。

我已在GitHub(Tensorflow repo和firebase repo)Github link上发布了此问题,但没有得到任何答案,我确实需要解决此问题。

  • 使用冻结推理(python脚本)的结果:

result with frozen inference graph

  • 使用tflite模型的结果(Android设备)

result using tflite model

那是一些Java代码(Android):

 private void useInferenceResult(float[] probabilities) throws IOException {
        // [START mlkit_use_inference_result]
        String[] result=new String[80];
        String x="";
        String y="";
        ArrayList<Point> listpoint= new ArrayList<Point>();
        double viewWidth = canvas.getWidth();
        double viewHeight = canvas.getHeight();
        double imageWidth = mutableBitmap.getWidth();
        double imageHeight = mutableBitmap.getHeight();
        Log.i("viewWidth","viewwidth "+viewWidth);
        Log.i("viewHeight","viewheight "+viewHeight);
        Log.i("imagewidth","imagewidth "+imageWidth);
        Log.i("imaageHeigh","imageheigh "+imageHeight);
        double scale = Math.min(viewWidth / imageWidth, viewHeight / imageHeight);
        Log.i("Scale","Scale"+scale);
        try {
            for (int i = 0; i < probabilities.length; i++) {

                Log.i("MLKit", String.format("%1.8f", probabilities[i]));
                float i1 = probabilities[i];
                Log.i("floaaat", "" + i1);
                x = String.format("%1.8f", probabilities[i]);
                y = String.format("%1.8f", probabilities[i + 1]);
                Point p = new Point(x, y);
                i = i + 1;
                p.setX(x);
                p.setY(y);
                listpoint.add(p);              
            }
        }
        catch(Exception exc){
            Log.e("Exception","Error:  "+exc);
        }
        for(int j=0;j<listpoint.size();j++){
            try {

                String xx = listpoint.get(j).getX();
                String yy = listpoint.get(j).getY();
                xx=xx.replace(",",".");
                yy=yy.replace(",",".");          
                float xx1 = Float.parseFloat(xx);
                float yy1 = Float.parseFloat(yy);
                Log.i("Float results", "point_" + j + "(" + xx1 + ", " + yy1 + ")");
                Log.i("Scale","Scale  "+scale);
                drawpoint(image2, (xx1*(float)scale*293) , (yy1*(float)scale*293) , 1);
            }
            catch(Exception esa){
                Log.e("Exception","Exception:  "+esa);
                Toast.makeText(this, "Exception"+esa, Toast.LENGTH_SHORT).show();
            }

        }

        }

// drawbitmap function

  private double drawBitmap(Canvas canvas) {
        double viewWidth = canvas.getWidth();
        double viewHeight = canvas.getHeight();
        double imageWidth = mutableBitmap.getWidth();
        double imageHeight = mutableBitmap.getHeight();
        double scale = Math.min(viewWidth / imageWidth, viewHeight / imageHeight);

        Rect destBounds = new Rect(0, 0, (int)(imageWidth * scale), (int)(imageHeight * scale));
        canvas.drawBitmap(mutableBitmap, null, destBounds, null);
        return scale;
    }

如何解决此错误?

0 个答案:

没有答案