我已经用keras创建了图像分类器,然后将模型保存为.pb
格式。该模型很简单,它只能检测猫和狗。
我想在android中使用该保存的模型。
这是我到目前为止尝试过的。我正在使用以下依赖项
implementation 'org.tensorflow:tensorflow-android:1.13.1'
然后将图像转换为字节数组。
private static final String MODEL_FILE = "file:///android_asset/tf_model.pb";
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
inferenceInterface.feed("conv2d_4_input", readBytes(ims), 1,64,64);
inferenceInterface.fetch("output_node/Softmax",output);
inferenceInterface.run(new String[]{"output_node/Softmax"});
但是出现以下错误。
Caused by: java.lang.IllegalArgumentException: buffer with 136886 elements is not compatible with a Tensor with shape [1, 64, 64]
不确定,我现在该怎么办。但是,我可以使用以下代码在python中对图像进行分类。
import numpy as np
from keras.preprocessing import image
test_image = image.load_img('dataset/single_prediction/cat_or_dog_1.jpg', target_size = (64, 64))
test_image = image.img_to_array(test_image)
test_image = np.expand_dims(test_image, axis = 0)
result = classifier.predict(test_image)
training_set.class_indices
if result[0][0] == 1:
prediction = 'dog'
else:
prediction = 'cat'