如何在Android

时间:2019-05-24 04:30:32

标签: android tensorflow deep-learning conv-neural-network

我已经用keras创建了图像分类器,然后将模型保存为.pb格式。该模型很简单,它只能检测猫和狗。

我想在android中使用该保存的模型。

这是我到目前为止尝试过的。我正在使用以下依赖项

implementation 'org.tensorflow:tensorflow-android:1.13.1'

然后将图像转换为字节数组。

private static final String MODEL_FILE = "file:///android_asset/tf_model.pb";

TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
inferenceInterface.feed("conv2d_4_input", readBytes(ims),  1,64,64);
inferenceInterface.fetch("output_node/Softmax",output);
inferenceInterface.run(new String[]{"output_node/Softmax"});

但是出现以下错误。

   Caused by: java.lang.IllegalArgumentException: buffer with 136886 elements is not compatible with a Tensor with shape [1, 64, 64]

不确定,我现在该怎么办。但是,我可以使用以下代码在python中对图像进行分类。

import numpy as np
from keras.preprocessing import image
test_image = image.load_img('dataset/single_prediction/cat_or_dog_1.jpg', target_size = (64, 64))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis = 0)
result = classifier.predict(test_image)
training_set.class_indices
if result[0][0] == 1:
    prediction = 'dog'
else:
    prediction = 'cat'


0 个答案:

没有答案