我在State Farm Distracted Driver Dataset上训练了Keras模型,并使用以下代码将其导出到 .pb 图形文件中:
def export_model(saver, model, input_node_names, output_node_name):
    """Export the current Keras session graph to a frozen, inference-optimized .pb.

    Pipeline: write the GraphDef as text -> checkpoint the weights ->
    freeze variables into constants -> strip training-only ops for inference.
    (TF1-style API: relies on K.get_session(), tf.gfile and freeze_graph.)
    """
    # 1. Dump the graph structure (no weights) as a text protobuf.
    tf.train.write_graph(K.get_session().graph_def, 'out',
                         MODEL_NAME + '_graph.pbtxt')

    # 2. Checkpoint the trained weights so freeze_graph can fold them in.
    saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')

    # 3. Merge graph + checkpoint into a single frozen binary graph.
    freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None,
                              False, 'out/' + MODEL_NAME + '.chkp',
                              output_node_name,
                              "save/restore_all", "save/Const:0",
                              'out/frozen_' + MODEL_NAME + '.pb', True, "")

    # 4. Re-load the frozen graph and optimize it for inference
    #    (removes dropout/identity/training ops between input and output).
    frozen_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
        frozen_graph_def.ParseFromString(f.read())

    optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
        f.write(optimized_graph_def.SerializeToString())

    print("graph saved!")
问题是:当我在 Android 上运行导出的模型时,它总是预测出错误的类别。而当使用同一来源提供的、在 MNIST 数据集上训练好的 .pb 模型文件时,下面这段相同的 Android 代码却能预测出正确的标签。
public class MainActivity extends AppCompatActivity {
ImageView imageView;
Button button1;
TextView label;
private static final int PixelWidth = 100;
//private List<Classifier> mClassifiers = new ArrayList<>();
public static final int cam_req = 999; // used in OpenCamera Function
public Classifier c11;
private static final int PICK_IMAGE = 1;
Uri imageUri;
private boolean isLoaded= true;
private float[] imageNormalizedPixels;
private int [] imageBitmapPixels;
private byte[] bytearray;
@Override
protected void onCreate(Bundle savedInstanceState) {
c11 = new ModelClassifier();
loadModel();
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
imageView = (ImageView) findViewById(R.id.imageView);
label = (TextView)findViewById(R.id.label_output);
button1 = (Button) findViewById(R.id.button);
button1.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
openGallery();
}
}
);
}
private void loadModel() {
//The Runnable interface is another way in which you can implement multi-threading other than extending the
// //Thread class due to the fact that Java allows you to extend only one class. Runnable is just an interface,
// //which provides the method run.
// //Threads are implementations and use Runnable to call the method run().
new Thread(new Runnable() {
@Override
public void run() {
try {
c11=ModelClassifier.create(getAssets(), "Keras", "opt_keras-10grey.pb", "labels.txt", PixelWidth,
"input_3", "softmax_1/Softmax", false);
} catch (final Exception e) {
//if they aren't found, throw an error!
label.setText("Model not loaded");
isLoaded = false;
throw new RuntimeException("Error initializing classifiers!", e);
}
}
}).start();
}
private void openGallery() {
if (isLoaded) {
Intent gallery = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI);
startActivityForResult(gallery, PICK_IMAGE);
}
}
public void OpenCamera(View view) {
if (isLoaded) {
Intent intnt = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
startActivityForResult(intnt, cam_req);
}
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent Data) {
super.onActivityResult(requestCode, resultCode, Data);
if (resultCode == RESULT_OK && requestCode == PICK_IMAGE) {
try {
final Uri imageUri = Data.getData();
final InputStream imageStream = getContentResolver().openInputStream(imageUri);
final Bitmap selectedImage = BitmapFactory.decodeStream(imageStream);
ImageView imageView = (ImageView) findViewById(R.id.imageView);
imageView.setImageBitmap(selectedImage);
imageView.invalidate();
BitmapDrawable drawable = (BitmapDrawable) imageView.getDrawable();
Bitmap image = Bitmap.createScaledBitmap(selectedImage, PixelWidth, PixelWidth, false);
float [] pixels = getPixels(image);
if (pixels!=null) {
System.out.println("It is gallery image");
final Classification returned = c11.recognize(pixels);
System.out.println("this is the predicted class: "+returned.getLabel());
label.setText(returned.getLabel());
}
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
public float[] getPixels(Bitmap bitmap)
{
// Get 100x100 pixel data from bitmap
int[] pixels = new int[PixelWidth * PixelWidth];
bitmap.getPixels(pixels, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
float[] retPixels = new float[pixels.length];
for (int i = 0; i < pixels.length; ++i) {
// Set 0 for white and 255 for black pixel
int pix = pixels[i];
retPixels[i] = pix /255f;
}
return retPixels;
}
}
ModelClassifier.java 类如下:
public class ModelClassifier implements Classifier {
private static final float THRESHOLD = 0.1f;
private TensorFlowInferenceInterface tfHelper;
private String name;
private String InputName;
private String OutputName;
private int inputsize;
private boolean feedKeepProb;
private List<String> Labels;
private float[] Output;
private String OutputNames[];
/*public ModelClassier(Context a){
Intent a1=a;
}*/
/**
 * Reads the class labels from an asset file, one label per line, in file order
 * (index i of the returned list corresponds to output neuron i).
 *
 * @throws IOException if the asset cannot be opened or read
 */
private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
    List<String> labels = new ArrayList<>();
    // try-with-resources closes the reader even if readLine() throws
    // (the original leaked the stream on exception). The explicit "UTF-8"
    // charset avoids the platform-default encoding without new imports.
    try (BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName), "UTF-8"))) {
        String line;
        while ((line = br.readLine()) != null) {
            labels.add(line);
        }
    }
    return labels;
}
/**
 * Factory: loads the frozen graph and label file from assets and returns a
 * ready-to-use classifier.
 *
 * @param assetManager asset source for the model and label files
 * @param name         human-readable classifier name
 * @param modelPath    asset path of the optimized frozen .pb graph
 * @param labelFile    asset path of the labels file (one label per line)
 * @param inputSize    width/height of the square model input
 * @param inputName    name of the graph's input tensor
 * @param outputName   name of the graph's output (softmax) tensor
 * @param feedKeepProb whether the graph has a "keep_prob" dropout placeholder to feed
 * @throws IOException if the model or label file cannot be read
 */
public static ModelClassifier create(AssetManager assetManager, String name,
        String modelPath, String labelFile, int inputSize, String inputName, String outputName,
        boolean feedKeepProb) throws IOException {
    ModelClassifier c = new ModelClassifier();
    c.name = name;
    c.InputName = inputName;
    c.OutputName = outputName;
    c.Labels = readLabels(assetManager, labelFile);
    System.out.println("These are the labels"+c.Labels.toString());
    c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);
    System.out.println("ModelPath is++++++++++" + modelPath);
    // BUG FIX: the class count was hard-coded to 10. Size the output buffer from
    // the label file instead, so retraining with a different number of classes
    // doesn't misalign (or overflow) the fetched softmax scores.
    int numClasses = c.Labels.size();
    c.inputsize = inputSize;
    c.OutputNames = new String[]{outputName};
    c.Output = new float[numClasses];
    c.feedKeepProb = feedKeepProb;
    System.out.println("Model is finally loaded++++++++++");
    return c;
}
/** Returns the human-readable name this classifier was created with. */
//@Override
public String name() {
    return this.name;
}
@Override
public Classification recognize(float[] pixels) {
    // Feed the normalized pixel buffer as a single [1, H, W, 1] grayscale tensor.
    tfHelper.feed(InputName, pixels, new long[]{1, inputsize, inputsize, 1});

    if (feedKeepProb) {
        // A keep-probability of 1.0 disables dropout at inference time.
        tfHelper.feed("keep_prob", new float[]{1});
        System.out.println("Infeedprob++++++++++");
    }

    // Run the graph and copy the softmax scores into the Output buffer.
    tfHelper.run(OutputNames);
    tfHelper.fetch(OutputName, Output);
    System.out.println("afterfetch++++++++++");

    // Keep the highest-scoring class that clears the confidence threshold;
    // if none does, the returned Classification carries no label.
    Classification ans = new Classification();
    for (int idx = 0; idx < Output.length; idx++) {
        float score = Output[idx];
        if (score > THRESHOLD && score > ans.getConf()) {
            ans.update(score, Labels.get(idx));
        }
    }
    return ans;
}
请注意,模型输入形状为[100,100,1]
请告诉我我错在哪里。我尝试过用不同数量的类别(classes)进行训练,但输出仍然完全不对。