我是android开发的新手,我希望创建一个自定义机器学习模型(具体来说是自定义对象检测器)支持的android应用。
我已经用图像数据训练了一个CNN,并且我拥有必需的“graph.pb”文件。我也将模型转换为了“.tflite”格式,并已将该文件作为自定义模型上传到Firebase控制台。我面临的问题是如何实际将此模型集成到应用程序中。我阅读了有关FirebaseRemoteModel、FirebaseModelOptions等的文档,但我不明白如何在android应用中实现它们。
在另一个项目中,我使用了tensorflow-for-poets-2,但看起来,该模型包含了apk文件本身,就像本地模型一样。
如上所述,我是android studio的初学者,以下代码是我参考网上的不同资料拼凑出来的。我正在开发的android应用基于android studio的侧边导航抽屉模板,目前仅在应用中实现了相机功能。
我的要求是使用相机捕获图像,然后使用Firebase控制台上托管的TensorFlow-lite模型对图像进行预测。
这是android项目的MainActivity.java。因此,大多数字段未链接。
package com.example.firebase_app;
import android.content.Intent;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.view.Menu;
import android.view.MenuItem;
import android.view.View;

import androidx.appcompat.app.ActionBarDrawerToggle;
import androidx.appcompat.app.AppCompatActivity;
import androidx.appcompat.widget.Toolbar;
import androidx.core.view.GravityCompat;
import androidx.drawerlayout.widget.DrawerLayout;

import com.google.android.material.floatingactionbutton.FloatingActionButton;
import com.google.android.material.navigation.NavigationView;
import com.google.android.material.snackbar.Snackbar;
/**
 * Main entry activity: hosts the navigation drawer and a FAB that launches the
 * device camera. The captured photo is delivered back to this activity via
 * {@link #onActivityResult}, from where it can be handed to the model code.
 */
public class MainActivity extends AppCompatActivity
implements NavigationView.OnNavigationItemSelectedListener{

/** Request code used to match the camera result in {@link #onActivityResult}. */
private static final int REQUEST_IMAGE_CAPTURE = 1;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Toolbar toolbar = findViewById(R.id.toolbar);
setSupportActionBar(toolbar);
FloatingActionButton fab = findViewById(R.id.fab);
fab.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
// BUG FIX: the original code fired INTENT_ACTION_STILL_IMAGE_CAMERA with
// startActivity(), which merely opens the camera app -- the photo never
// comes back to this application, so it could not be fed to the model.
// ACTION_IMAGE_CAPTURE + startActivityForResult() returns the captured
// image in onActivityResult().
Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
}
}
});
DrawerLayout drawer = findViewById(R.id.drawer_layout);
NavigationView navigationView = findViewById(R.id.nav_view);
ActionBarDrawerToggle toggle = new ActionBarDrawerToggle(
this, drawer, toolbar, R.string.navigation_drawer_open, R.string.navigation_drawer_close);
drawer.addDrawerListener(toggle);
toggle.syncState();
navigationView.setNavigationItemSelectedListener(this);
}

/**
 * Receives the photo captured by the camera app. The camera returns a small
 * thumbnail in the "data" extra, which is sufficient for a 224x224 classifier
 * input; for full-resolution images, pass an EXTRA_OUTPUT uri instead.
 */
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (requestCode == REQUEST_IMAGE_CAPTURE && resultCode == RESULT_OK
&& data != null && data.getExtras() != null) {
Bitmap imageBitmap = (Bitmap) data.getExtras().get("data");
// TODO: hand imageBitmap to the inference code (see CustomModelActivity)
// instead of discarding it.
}
}

/** Closes the drawer on back-press when it is open; otherwise default behavior. */
@Override
public void onBackPressed() {
DrawerLayout drawer = findViewById(R.id.drawer_layout);
if (drawer.isDrawerOpen(GravityCompat.START)) {
drawer.closeDrawer(GravityCompat.START);
} else {
super.onBackPressed();
}
}

@Override
public boolean onCreateOptionsMenu(Menu menu) {
// Inflate the menu; this adds items to the action bar if it is present.
getMenuInflater().inflate(R.menu.main, menu);
return true;
}

@Override
public boolean onOptionsItemSelected(MenuItem item) {
// Handle action bar item clicks here. The action bar will
// automatically handle clicks on the Home/Up button, so long
// as you specify a parent activity in AndroidManifest.xml.
int id = item.getItemId();
//noinspection SimplifiableIfStatement
if (id == R.id.action_settings) {
return true;
}
return super.onOptionsItemSelected(item);
}

@SuppressWarnings("StatementWithEmptyBody")
@Override
public boolean onNavigationItemSelected(MenuItem item) {
// Handle navigation view item clicks here.
int id = item.getItemId();
if (id == R.id.nav_home) {
// Handle the camera action
} else if (id == R.id.nav_gallery) {
} else if (id == R.id.nav_slideshow) {
} else if (id == R.id.nav_tools) {
} else if (id == R.id.nav_share) {
} else if (id == R.id.nav_send) {
}
DrawerLayout drawer = findViewById(R.id.drawer_layout);
drawer.closeDrawer(GravityCompat.START);
return true;
}
}
在阅读了一些有关Firebase的文档之后,我在项目中创建了如下的自定义模型Activity类(Java):
package com.example.firebase_app;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import android.util.Log;
import com.google.android.gms.tasks.OnFailureListener;
import com.google.android.gms.tasks.OnSuccessListener;
import com.google.firebase.ml.common.FirebaseMLException;
import com.google.firebase.ml.common.modeldownload.FirebaseLocalModel;
import com.google.firebase.ml.common.modeldownload.FirebaseModelDownloadConditions;
import com.google.firebase.ml.common.modeldownload.FirebaseModelManager;
import com.google.firebase.ml.common.modeldownload.FirebaseRemoteModel;
import com.google.firebase.ml.custom.FirebaseModelDataType;
import com.google.firebase.ml.custom.FirebaseModelInputOutputOptions;
import com.google.firebase.ml.custom.FirebaseModelInputs;
import com.google.firebase.ml.custom.FirebaseModelInterpreter;
import com.google.firebase.ml.custom.FirebaseModelOptions;
import com.google.firebase.ml.custom.FirebaseModelOutputs;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
/**
 * Demonstrates running inference with a custom TFLite model through the
 * Firebase ML Kit custom-model API: registers a Firebase-hosted (remote)
 * model plus a bundled (local) fallback, converts a bitmap to the model's
 * float input tensor, runs the interpreter, and maps the output
 * probabilities to labels.
 */
public class CustomModelActivity extends AppCompatActivity {

/** Name assigned to the hosted model in the Firebase console. */
private static final String REMOTE_MODEL_NAME = "crop-disease-detection";
/** Name under which the bundled fallback model is registered. */
private static final String LOCAL_MODEL_NAME = "my_local_model";
/** Input image side length expected by the model (see createInputOutputOptions). */
private static final int INPUT_SIZE = 224;

/**
 * Registers the Firebase-hosted model source. Must be called (together with
 * {@link #configureLocalModelSource()}) before {@link #createInterpreter()}.
 */
private void configureHostedModelSource() {
// [START mlkit_cloud_model_source]
FirebaseModelDownloadConditions.Builder conditionsBuilder =
new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
// Enable advanced conditions on Android Nougat and newer.
conditionsBuilder = conditionsBuilder
.requireCharging()
.requireDeviceIdle();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();
// The name must match the one assigned when the model was uploaded
// in the Firebase console.
FirebaseRemoteModel cloudSource = new FirebaseRemoteModel.Builder(REMOTE_MODEL_NAME)
.enableModelUpdates(true)
.setInitialDownloadConditions(conditions)
.setUpdatesDownloadConditions(conditions)
.build();
FirebaseModelManager.getInstance().registerRemoteModel(cloudSource);
// [END mlkit_cloud_model_source]
}

/**
 * Registers the locally bundled fallback model.
 *
 * BUG FIX: the model must be referenced by its asset path, not an absolute
 * desktop path ("/media/piyush/...") which does not exist on the device.
 * Copy optimized_model.tflite into app/src/main/assets/ and add
 * aaptOptions { noCompress "tflite" } to build.gradle so the asset is not
 * compressed.
 */
private void configureLocalModelSource() {
// [START mlkit_local_model_source]
FirebaseLocalModel localSource =
new FirebaseLocalModel.Builder(LOCAL_MODEL_NAME) // Assign a name to this model
.setAssetFilePath("optimized_model.tflite")
.build();
FirebaseModelManager.getInstance().registerLocalModel(localSource);
// [END mlkit_local_model_source]
}

/**
 * Builds an interpreter that prefers the hosted model and falls back to the
 * bundled one when the download has not completed yet.
 *
 * @throws FirebaseMLException if neither model source can be resolved
 */
private FirebaseModelInterpreter createInterpreter() throws FirebaseMLException {
// [START mlkit_create_interpreter]
FirebaseModelOptions options = new FirebaseModelOptions.Builder()
.setRemoteModelName(REMOTE_MODEL_NAME)
.setLocalModelName(LOCAL_MODEL_NAME)
.build();
FirebaseModelInterpreter firebaseInterpreter =
FirebaseModelInterpreter.getInstance(options);
// [END mlkit_create_interpreter]
return firebaseInterpreter;
}

/**
 * Declares the model's tensor shapes: a single 224x224 RGB float image in,
 * a 5-way probability vector out. These must match the converted .tflite
 * model exactly or run() will fail.
 */
private FirebaseModelInputOutputOptions createInputOutputOptions() throws FirebaseMLException {
// [START mlkit_create_io_options]
FirebaseModelInputOutputOptions inputOutputOptions =
new FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, INPUT_SIZE, INPUT_SIZE, 3})
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 5})
.build();
// [END mlkit_create_io_options]
return inputOutputOptions;
}

/**
 * Scales the input image to 224x224 and converts it to the model's
 * [1][224][224][3] float tensor, normalizing each channel to [-1.0, 1.0].
 */
private float[][][][] bitmapToInputArray() {
// [START mlkit_bitmap_input]
Bitmap bitmap = getYourInputImage();
bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true);
int batchNum = 0;
float[][][][] input = new float[1][INPUT_SIZE][INPUT_SIZE][3];
for (int x = 0; x < INPUT_SIZE; x++) {
for (int y = 0; y < INPUT_SIZE; y++) {
int pixel = bitmap.getPixel(x, y);
// Normalize channel values to [-1.0, 1.0]. This requirement varies by
// model. For example, some models might require values to be normalized
// to the range [0.0, 1.0] instead.
input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 128.0f;
input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 128.0f;
input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 128.0f;
}
}
// [END mlkit_bitmap_input]
return input;
}

/**
 * Runs the model asynchronously and forwards the resulting probability
 * vector to {@link #useInferenceResult(float[])}.
 *
 * BUG FIX: the original onSuccess callback extracted the probabilities and
 * then silently discarded them.
 */
private void runInference() throws FirebaseMLException {
FirebaseModelInterpreter firebaseInterpreter = createInterpreter();
float[][][][] input = bitmapToInputArray();
FirebaseModelInputOutputOptions inputOutputOptions = createInputOutputOptions();
// [START mlkit_run_inference]
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
.add(input) // add() as many input arrays as your model requires
.build();
firebaseInterpreter.run(inputs, inputOutputOptions)
.addOnSuccessListener(
new OnSuccessListener<FirebaseModelOutputs>() {
@Override
public void onSuccess(FirebaseModelOutputs result) {
// [START_EXCLUDE]
// [START mlkit_read_result]
float[][] output = result.getOutput(0);
float[] probabilities = output[0];
// [END mlkit_read_result]
// [END_EXCLUDE]
try {
useInferenceResult(probabilities);
} catch (IOException e) {
Log.e("MLKit", "Failed to read labels file", e);
}
}
})
.addOnFailureListener(
new OnFailureListener() {
@Override
public void onFailure(@NonNull Exception e) {
// Log instead of swallowing: a typical cause is an I/O-shape
// mismatch or a model that has not finished downloading.
Log.e("MLKit", "Inference failed", e);
}
});
// [END mlkit_run_inference]
}

/**
 * Logs each label with its predicted probability.
 *
 * BUG FIXES: the labels file is opened by its asset name (it must live in
 * app/src/main/assets/, not at a desktop path), and the reader is closed
 * via try-with-resources instead of being leaked.
 *
 * @param probabilities model output, one entry per line of the labels file
 * @throws IOException if the labels asset cannot be read
 */
private void useInferenceResult(float[] probabilities) throws IOException {
// [START mlkit_use_inference_result]
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(getAssets().open("retrained_labels.txt")))) {
for (int i = 0; i < probabilities.length; i++) {
String label = reader.readLine();
Log.i("MLKit", String.format("%s: %1.4f", label, probabilities[i]));
}
}
// [END mlkit_use_inference_result]
}

/**
 * Placeholder image source.
 *
 * BUG FIX: Bitmap.createBitmap(0, 0, ...) throws IllegalArgumentException
 * (width and height must be > 0); use a valid 224x224 placeholder. Replace
 * this with the bitmap captured by the camera.
 */
private Bitmap getYourInputImage() {
// This method is just for show
return Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);
}
}
AndroidManifest 文件如下:
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.example.firebase_app">

    <uses-permission android:name="android.permission.INTERNET" />
    <uses-permission android:name="android.permission.CAMERA" />
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />

    <!-- BUG FIX: the attributes below were placed AFTER the closing '>' of
         <application>, leaving them as stray text and producing a malformed
         manifest. They must be inside the opening tag. -->
    <application
        android:allowBackup="true"
        android:icon="@mipmap/ic_launcher"
        android:label="E-Agri"
        android:roundIcon="@mipmap/ic_launcher_round"
        android:supportsRtl="true"
        android:theme="@style/AppTheme">

        <!-- BUG FIX: activity names must be fully qualified or start with '.'
             so they resolve against the package attribute. -->
        <activity
            android:name=".MainActivity"
            android:label="E-Agri"
            android:theme="@style/AppTheme.NoActionBar">
            <intent-filter>
                <action android:name="android.intent.action.MAIN" />
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>
        <activity android:name="com.example.firebase_app.CustomModelActivity" />
    </application>
</manifest>
我想知道如何使用Firebase托管的tflite模型对捕获的图像进行推断。任何帮助将不胜感激。即使经过大量搜索,我也找不到针对“自定义模型”的资料:现有资源都是关于ML Kit内置API的。另外,由于我已经在海量数据上训练好了CNN,因此我不希望在该项目中使用AutoML。