How to use a TFLite model with uint8 input and output data types in Firebase image labeling

Asked: 2019-07-31 13:06:38

Tags: java python android tensorflow-lite firebase-mlkit

I created an image classification model with AutoML in Firebase. Its input and output tensors are:

[  1 224 224   3]
<class 'numpy.uint8'>
[ 1 11]
<class 'numpy.uint8'>

But FirebaseModelDataType has no uint8 data type; it only supports INT32, FLOAT32, BYTE, and LONG. What should I do?

FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.getInstance(options);
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 224, 224, 3})
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 11})
        .build();

This code will not run, because the model's input and output are uint8.
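
For a uint8 model, FirebaseModelDataType.BYTE is the closest match in this API, so a byte-based declaration would presumably look like the sketch below (untested; the accepted answer below takes a different route entirely):

FirebaseModelInputOutputOptions byteOptions = new FirebaseModelInputOutputOptions.Builder()
        .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 224, 224, 3})
        .setOutputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 11})
        .build();

// Inputs and outputs are then plain byte buffers, one unsigned byte per value:
byte[][][][] input = new byte[1][224][224][3];
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder().add(input).build();
interpreter.run(inputs, byteOptions);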

1 Answer:

Answer 0 (score: 1)

I finally got it working. It turns out that an AutoML model is used differently from a custom model: instead of running it through FirebaseModelInterpreter with declared input/output formats, you register it as a local model and hand it to the on-device AutoML image labeler, which takes care of the uint8 tensors internally.

private void startLabel() {
    // Register the AutoML model bundled in assets/ (manifest.json points at
    // the .tflite file and its label list).
    FirebaseLocalModel localModel = new FirebaseLocalModel.Builder("my_local_model")
            .setAssetFilePath("manifest.json")
            .build();
    FirebaseModelManager.getInstance().registerLocalModel(localModel);

    // Label the current camera preview frame every 2 seconds.
    timer = new Timer();
    timer.schedule(new TimerTask() {
        @Override
        public void run() {
            Bitmap frame = textureView.getBitmap();
            if (frame == null) {
                return; // TextureView surface not ready yet
            }
            FirebaseVisionImage image = FirebaseVisionImage.fromBitmap(frame);
            FirebaseVisionOnDeviceAutoMLImageLabelerOptions labelerOptions =
                    new FirebaseVisionOnDeviceAutoMLImageLabelerOptions.Builder()
                            .setLocalModelName("my_local_model")
                            .setConfidenceThreshold(0.55f)
                            .build();
            try {
                FirebaseVisionImageLabeler labeler =
                        FirebaseVision.getInstance().getOnDeviceAutoMLImageLabeler(labelerOptions);
                labeler.processImage(image)
                        // Task listeners are invoked on the main thread, so UI calls are safe here.
                        .addOnSuccessListener(new OnSuccessListener<List<FirebaseVisionImageLabel>>() {
                            @Override
                            public void onSuccess(List<FirebaseVisionImageLabel> firebaseVisionImageLabels) {
                                if (!firebaseVisionImageLabels.isEmpty()) {
                                    // Show (and speak) the top label.
                                    MoneyReader.this.result.setText(firebaseVisionImageLabels.get(0).getText());
                                    if (isTTSReady) {
                                        tts.speak(firebaseVisionImageLabels.get(0).getText(), TextToSpeech.QUEUE_ADD, null, "DEFAULT");
                                    }
                                } else {
                                    status.setText("Nothing Recognized");
                                }
                            }
                        })
                        .addOnFailureListener(new OnFailureListener() {
                            @Override
                            public void onFailure(@NonNull Exception e) {
                                Toast.makeText(MoneyReader.this, e.getMessage(), Toast.LENGTH_SHORT).show();
                            }
                        });
            } catch (final FirebaseMLException e) {
                // TimerTask runs on a background thread; Toasts must be shown on the main thread.
                runOnUiThread(new Runnable() {
                    @Override
                    public void run() {
                        Toast.makeText(MoneyReader.this, e.getMessage(), Toast.LENGTH_SHORT).show();
                    }
                });
            }
        }
    }, 0, 2000);
}
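
A note on the design: the AutoML image labeler drives the TFLite interpreter itself, so no FirebaseModelInputOutputOptions (and no uint8 declaration) are needed anywhere. The manifest.json passed to setAssetFilePath is the one that ships with the AutoML Vision Edge export, sitting next to the model and label files in assets/; it typically looks like this (the exact filenames depend on your export):

{
    "modelFile": "model.tflite",
    "labelsFile": "dict.txt",
    "modelType": "IMAGE_LABELING"
}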