Android Tensorflow IllegalArgumentException错误

时间:2017-08-22 13:24:49

标签: android tensorflow

我正在使用android studio和tensorflow,android版本进行图像识别。 它不是连续跟踪和识别,只是识别一个图像。 我已经有图形pb和标签txt文件,并设置了所需的设置。 但是有一个大问题。 关于图像,尺寸误差我反复得到同样的错误。 这是错误日志和我的源代码。

1D

我不知道问题出在哪里,第一行,[1,1,299,299,3]。我认为两个299是ImageSize,一个是ImageStd,但我不知道另外1和3是什么...... 我输入的代码与tensorflow github中的官方代码相同,只是更改了一些部分。 这是MainActivity。

# view1D is from https://stackoverflow.com/a/45313353/
def view1D(a, b): # a, b are arrays
    a = np.ascontiguousarray(a)
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel(),  b.view(void_dt).ravel()

def pattern_index_view1D(all_data, search_data):
    a = strided_app(np.asarray(all_data), L=len(search_data), S=1)
    a0v, b0v = view1D(np.asarray(a), np.asarray(search_data))
    return np.flatnonzero(np.in1d(a0v, b0v)) 

out = np.squeeze(pattern_index_view1D(l, m)[:,None] + np.arange(len(m)))

这是分类器,与官方代码相同。

java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3]
                                                                         [[Node: ResizeBilinear = ResizeBilinear[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](ExpandDims, ResizeBilinear/size)]]
                                                                         at org.tensorflow.Session.run(Native Method)
                                                                         at org.tensorflow.Session.access$100(Session.java:48)
                                                                         at org.tensorflow.Session$Runner.runHelper(Session.java:295)
                                                                         at org.tensorflow.Session$Runner.run(Session.java:245)
                                                                         at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:144)
                                                                         at com.example.yuuuuu.tensorTest.TensorFlowImageClassifier.recognizeImage(TensorFlowImageClassifier.java:119)
                                                                         at com.example.yuuuuu.tensorTest.MainActivity.runTensor(MainActivity.java:69)
                                                                         at com.example.yuuuuu.tensorTest.MainActivity$1.onClick(MainActivity.java:42)
                                                                         at android.view.View.performClick(View.java:6205)
                                                                         at android.widget.TextView.performClick(TextView.java:11103)
                                                                         at android.view.View$PerformClick.run(View.java:23653)
                                                                         at android.os.Handler.handleCallback(Handler.java:751)
                                                                         at android.os.Handler.dispatchMessage(Handler.java:95)
                                                                         at android.os.Looper.loop(Looper.java:154)
                                                                         at android.app.ActivityThread.main(ActivityThread.java:6682)
                                                                         at java.lang.reflect.Method.invoke(Native Method)
                                                                         at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1520)
                                                                         at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1410)

最后是TensorFlowImageClassifier,与官方一样。

public class MainActivity extends AppCompatActivity {

private static final String MODEL_FILE = "file:///android_asset/optimized_graph.pb";
private static final String LABEL_FILE = "file:///android_asset/output_labels.txt";
private static final String INPUT_NAME = "Cast";
private static final String OUTPUT_NAME = "final_result";
private static final int INPUT_SIZE = 299;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;

private Classifier classifier;
private TextView textView;
private ImageView img;
private Button button;

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    textView = (TextView)findViewById(R.id.textView);
    button = (Button)findViewById(R.id.btn);
    img = (ImageView)findViewById(R.id.img);

    button.setOnClickListener(new View.OnClickListener(){
        public void onClick(View v){
            runTensor();
        }
    });

    initTensor();
}

public void initTensor(){
    classifier = TensorFlowImageClassifier.create(
            getAssets(),
            MODEL_FILE,
            LABEL_FILE,
            INPUT_SIZE,
            IMAGE_MEAN,
            IMAGE_STD,
            INPUT_NAME,
            OUTPUT_NAME
    );
}

public void runTensor(){
    Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test);
    bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, false);

    img = (ImageView)findViewById(R.id.img);
    img.setImageBitmap(bitmap);

    final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);
    textView.setText(results.toString());
}

protected void onDestroy(){
    super.onDestroy();
    classifier.close();
}

}

如果您知道如何修复这些代码,请告诉我如何操作。

2 个答案:

答案 0 :(得分:1)

  

java.lang.IllegalArgumentException:输入必须是   4维[1,1,299,299,3]

错误消息解释了问题:您不小心传递了5项数组而不是4项数组。即你应该传递像[1,299,299,1]而不是[1,1,299,299,3]这样的东西。

很难从你的问题中判断出你实际改变了哪些代码。如果您可以将更改作为单个Git提交进行,那么有人可能更容易识别导致问题的更改?

您可以尝试在TensorBoard中查看TensorFlow模型,以检查输入和输出节点,检查它们是否与您配置的值匹配:
https://medium.com/@daj/how-to-inspect-a-pre-trained-tensorflow-model-5fd2ee79ced0

答案 1 :(得分:0)

好吧,当我使用本机库时,我注意到他们通常不会从资产中获取文件,您需要将其复制到不可访问的文件存储路径并将绝对路径传递给库。

您的错误可能来自加载资源。