导出的 Keras 模型在 Android 上预测不正确

时间:2019-02-24 19:20:29

标签: android keras tensorflow-lite

我在 State Farm Distracted Driver 数据集上训练了一个 Keras 模型，并使用以下代码将其导出为 .pb 计算图（graph）文件：

def export_model(saver, model, input_node_names, output_node_name, out_dir='out'):
    """Export the current Keras session graph as frozen and inference-optimized .pb files.

    Files written under ``out_dir``:
      * ``<MODEL_NAME>_graph.pbtxt`` -- text GraphDef of ``K.get_session()``;
      * ``<MODEL_NAME>.chkp*``       -- checkpoint written by ``saver.save``;
      * ``frozen_<MODEL_NAME>.pb``   -- output of ``freeze_graph`` (variables as constants);
      * ``opt_<MODEL_NAME>.pb``      -- the frozen graph after ``optimize_for_inference``.

    Args:
        saver: saver used to checkpoint the session variables. The restore-op and
            filename-tensor names passed to ``freeze_graph`` below ("save/restore_all",
            "save/Const:0") presumably assume it was created with the default name
            "save" -- TODO confirm.
        model: not used by this function; kept for interface compatibility.
        input_node_names: name, or list of names, of the graph input node(s).
        output_node_name: name of the single output node to keep.
        out_dir: directory the files are written to (defaults to 'out', as before).

    NOTE(review): if the model contains Dropout/BatchNormalization, Keras' learning
    phase should presumably be fixed to inference (``K.set_learning_phase(0)``)
    *before* the model is built/loaded, otherwise the exported graph may carry
    training-time behaviour -- confirm against how the model is created.
    """
    import os

    # optimize_for_inference expects a list of input names; accept a single string too.
    if isinstance(input_node_names, str):
        input_node_names = [input_node_names]

    sess = K.get_session()
    graph_name = MODEL_NAME + '_graph.pbtxt'
    graph_path = os.path.join(out_dir, graph_name)
    ckpt_path = os.path.join(out_dir, MODEL_NAME + '.chkp')
    frozen_path = os.path.join(out_dir, 'frozen_' + MODEL_NAME + '.pb')
    opt_path = os.path.join(out_dir, 'opt_' + MODEL_NAME + '.pb')

    # Text GraphDef (as_text defaults to True), hence input_binary=False below.
    tf.train.write_graph(sess.graph_def, out_dir, graph_name)
    saver.save(sess, ckpt_path)

    # Positional args: input_graph, input_saver, input_binary, input_checkpoint,
    # output_node_names, restore_op_name, filename_tensor_name, output_graph,
    # clear_devices, initializer_nodes.
    freeze_graph.freeze_graph(graph_path, None, False, ckpt_path, output_node_name,
                              "save/restore_all", "save/Const:0",
                              frozen_path, True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile(opt_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")

问题是：在 Android 上运行导出的模型时，它总是预测出错误的类别。而使用 source（原帖中的链接）提供的、在 MNIST 数据集上训练得到的 .pb 文件时，下面完全相同的代码却能预测出正确的标签。

public class MainActivity extends AppCompatActivity {
    ImageView imageView;
    Button button1;
    TextView label;
    private static final int PixelWidth = 100;
    //private List<Classifier> mClassifiers = new ArrayList<>();
    public static final int cam_req = 999; // used in OpenCamera Function
    public Classifier c11;
    private static final int PICK_IMAGE = 1;
    Uri imageUri;
    private boolean isLoaded= true;
    private float[] imageNormalizedPixels;
    private int [] imageBitmapPixels;
    private byte[] bytearray;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        c11 = new ModelClassifier();
        loadModel();

        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        imageView = (ImageView) findViewById(R.id.imageView);
        label = (TextView)findViewById(R.id.label_output);
        button1 = (Button) findViewById(R.id.button);
        button1.setOnClickListener(new View.OnClickListener() {
                                       @Override
                                       public void onClick(View v) {
                                           openGallery();
                                       }
                                   }

        );


    }

    private void loadModel() {
        //The Runnable interface is another way in which you can implement multi-threading other than extending the
        // //Thread class due to the fact that Java allows you to extend only one class. Runnable is just an interface,
        // //which provides the method run.
        // //Threads are implementations and use Runnable to call the method run().
        new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    c11=ModelClassifier.create(getAssets(), "Keras", "opt_keras-10grey.pb", "labels.txt", PixelWidth,
                          "input_3", "softmax_1/Softmax", false);

                } catch (final Exception e) {
                    //if they aren't found, throw an error!
                    label.setText("Model not loaded");
                    isLoaded = false;
                    throw new RuntimeException("Error initializing classifiers!", e);
                }
            }
        }).start();
    }



    private void openGallery() {
        if (isLoaded) {
            Intent gallery = new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI);
            startActivityForResult(gallery, PICK_IMAGE);
        }

    }

    public void OpenCamera(View view) {

        if (isLoaded) {

            Intent intnt = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
            startActivityForResult(intnt, cam_req);
        }
    }


    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent Data) {
        super.onActivityResult(requestCode, resultCode, Data);
        if (resultCode == RESULT_OK && requestCode == PICK_IMAGE) {

            try {
                final Uri imageUri = Data.getData();
                final InputStream imageStream = getContentResolver().openInputStream(imageUri);
                final Bitmap selectedImage = BitmapFactory.decodeStream(imageStream);
                ImageView imageView = (ImageView) findViewById(R.id.imageView);
                imageView.setImageBitmap(selectedImage);
                imageView.invalidate();
                BitmapDrawable drawable = (BitmapDrawable) imageView.getDrawable();

                Bitmap image = Bitmap.createScaledBitmap(selectedImage, PixelWidth, PixelWidth, false);

                float [] pixels = getPixels(image);

                if (pixels!=null) {
                    System.out.println("It is gallery image");

                    final Classification returned = c11.recognize(pixels);
                    System.out.println("this is the predicted class: "+returned.getLabel());
                    label.setText(returned.getLabel());


                }

            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }

   }


    public float[] getPixels(Bitmap bitmap)
    {

        // Get 100x100 pixel data from bitmap
        int[] pixels = new int[PixelWidth * PixelWidth];
        bitmap.getPixels(pixels, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

        float[] retPixels = new float[pixels.length];
        for (int i = 0; i < pixels.length; ++i) {
            // Set 0 for white and 255 for black pixel
            int pix = pixels[i];
            retPixels[i] = pix /255f;
        }
        return retPixels;
    }

    }

ModelClassifier.java 类如下:



public class ModelClassifier implements Classifier {

    private static final float THRESHOLD = 0.1f;
    private TensorFlowInferenceInterface tfHelper;

    private String name;
    private String InputName;
    private String OutputName;

    private int inputsize;

    private boolean feedKeepProb;

    private List<String> Labels;
    private float[] Output;

    private String OutputNames[];


   /*public  ModelClassier(Context a){
       Intent a1=a;
    }*/

    private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName)));

        String line;
        List<String> labels = new ArrayList<>();
        while ((line = br.readLine()) != null) {
            labels.add(line);
        }

        br.close();
        return labels;
    }

    public static ModelClassifier create(AssetManager assetManager, String name,
                                         String modelPath, String labelFile, int inputSize, String inputName, String outputName,
                                         boolean feedKeepProb) throws IOException {

        ModelClassifier c = new ModelClassifier();


        c.name = name;

        c.InputName = inputName;
        c.OutputName = outputName;


        c.Labels = readLabels(assetManager, labelFile);
        System.out.println("These are the labels"+c.Labels.toString());

        c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);
        System.out.println("ModelPath is++++++++++" + modelPath);
        int numClasses = 10;


        c.inputsize = inputSize;


        c.OutputNames = new String[]{outputName};

        c.OutputName = outputName;
        c.Output = new float[numClasses];

        c.feedKeepProb = feedKeepProb;
        System.out.println("Model is finally loaded++++++++++");
        return c;
    }


    //@Override
    public String name() {
        return name;
    }


    @Override
    public Classification recognize(float[] pixels) {

        //using the interface
        //give it the input name, raw pixels from the drawing,
        //input size
        tfHelper.feed(InputName, pixels, new long[]{1, inputsize, inputsize, 1});

        //probabilities
        if (feedKeepProb) {
            tfHelper.feed("keep_prob", new float[]{1});
            System.out.println("Infeedprob++++++++++");
        }
        //get the possible outputs
        tfHelper.run(OutputNames);


        tfHelper.fetch(OutputName, Output);
        System.out.println("afterfetch++++++++++");
        // Find the best classification
        //for each output prediction
        //if its above the threshold for accuracy we predefined
        //write it out to the view
        Classification ans = new Classification();
        for (int i = 0; i < Output.length; ++i) {
            if (Output[i] > THRESHOLD && Output[i] > ans.getConf()) {
                ans.update(Output[i], Labels.get(i));
            }

        }

        return ans;
    }

请注意,模型输入形状为[100,100,1]

请告诉我哪里出错了。我尝试过用不同数量的类别进行训练，但输出结果完全不同。

0 个答案:

没有答案