I followed the steps at https://www.tensorflow.org/lite/performance/gpu and added the GPU delegate to the TFLite interpreter. The MobileNet model shipped with the Android project tensorflow/lite/java/demo works on a Pixel 2 inside that project. However, when I use the same model in another project with the code below, I get this error:
TfLiteGpuDelegate Invoke: Write to buffer failed. Source data is larger than buffer
Node number 31 (TfLiteGpuDelegateV2) failed to invoke.
imgData and outputClasses have already been checked; they are the same sizes as in the tensorflow/lite/java/demo project.
The TfLite and TfLiteGpu dependency versions are also the same as in the tensorflow/lite/java/demo project.
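To narrow down where the size mismatch comes from, the interpreter's own view of the tensor sizes can be logged before invoking. This is a minimal sketch (the method name logTensorSizes is mine; tfLite and imgData are the fields from the class below):

  // Sketch: compare the interpreter's declared tensor sizes with the
  // Java-side buffers before calling run(). Requires: import java.util.Arrays;
  private void logTensorSizes() {
    android.util.Log.d("TFLite", "input shape: "
        + Arrays.toString(tfLite.getInputTensor(0).shape())
        + ", input bytes: " + tfLite.getInputTensor(0).numBytes()
        + ", imgData capacity: " + imgData.capacity());
    android.util.Log.d("TFLite", "output shape: "
        + Arrays.toString(tfLite.getOutputTensor(0).shape())
        + ", output bytes: " + tfLite.getOutputTensor(0).numBytes());
  }

The full class follows: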
import android.content.res.AssetManager;
import android.os.Trace;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
/**
 * Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
 * github.com/tensorflow/models/tree/master/research/object_detection
 */
public class TFLiteObjectDetectionAPIModel implements Classifier {
  private static final Logger LOGGER = new Logger(); // logging helper from the demo's env package

  // Only return this many results.
  private static final int NUM_DETECTIONS = 1001; // hand landmark: 21; MobileNet: 1001

  // Float model normalization constants.
  private static final float IMAGE_MEAN = 128.0f;
  private static final float IMAGE_STD = 128.0f;

  private boolean isModelQuantized;
  private int inputSize;
  private Vector<String> labels = new Vector<String>();
  private int[] intValues;
  private float[][][] outputLocations;
  private float[][] outputClasses;
  private float[][] outputScores;
  private float[] numDetections;

  private GpuDelegate gpuDelegate = new GpuDelegate();
  private final Interpreter.Options options = new Interpreter.Options().addDelegate(gpuDelegate);

  private ByteBuffer imgData;
  private Interpreter tfLite;
  private ArrayList<Recognition> recognitions = new ArrayList<Recognition>();

  private TFLiteObjectDetectionAPIModel() {}
  /** Memory-map the model file. Note: despite the assets parameter, this reads from external storage. */
  private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
      throws IOException {
    FileInputStream inputStream =
        new FileInputStream(new File("/sdcard/sunny/data/" + modelFilename + ".tflite"));
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileChannel.position(); // 0 for a freshly opened stream
    long declaredLength = fileChannel.size();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }
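For comparison, the tensorflow/lite/java/demo project memory-maps the model straight out of the APK's assets rather than from /sdcard. A sketch of that variant (the method name loadModelFromAssets is mine; requires import android.content.res.AssetFileDescriptor;):

  // Sketch: memory-map a model bundled in assets, as the demo project does.
  private static MappedByteBuffer loadModelFromAssets(AssetManager assets, String modelFilename)
      throws IOException {
    AssetFileDescriptor fd = assets.openFd(modelFilename + ".tflite");
    FileInputStream inputStream = new FileInputStream(fd.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    return fileChannel.map(
        FileChannel.MapMode.READ_ONLY, fd.getStartOffset(), fd.getDeclaredLength());
  }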
  /**
   * Initializes a native TensorFlow session for classifying images.
   *
   * @param assetManager The asset manager to be used to load assets.
   * @param modelFilename The name of the model file (resolved to /sdcard/sunny/data/<name>.tflite).
   * @param labelFilename The filepath of the label file for classes.
   * @param inputSize The size of the image input.
   * @param isQuantized Whether the model is quantized.
   */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final String labelFilename,
      final int inputSize,
      final boolean isQuantized)
      throws IOException {
    final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();

    // Load the label list shipped as an asset.
    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    InputStream labelsInput = assetManager.open(actualFilename);
    BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
      LOGGER.w(line);
      d.labels.add(line);
    }
    br.close();

    d.inputSize = inputSize;
    try {
      d.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename), d.options);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    d.isModelQuantized = isQuantized;

    // Pre-allocate buffers.
    int numBytesPerChannel;
    if (isQuantized) {
      numBytesPerChannel = 1; // Quantized
    } else {
      numBytesPerChannel = 4; // Floating point
    }
    // Use numBytesPerChannel instead of a hardcoded 4 so the quantized path
    // also allocates the right size.
    d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
    d.imgData.order(ByteOrder.nativeOrder());
    d.intValues = new int[d.inputSize * d.inputSize];
    d.outputClasses = new float[1][NUM_DETECTIONS];
    return d;
  }
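For context, the call site in the failing project looks roughly like this (a hypothetical invocation; the model name, label path, and input size are placeholders, not values from the project):

  // Hypothetical call site; values are illustrative only.
  Classifier classifier = TFLiteObjectDetectionAPIModel.create(
      getAssets(),                         // AssetManager from an Activity
      "mobilenet_v1_1.0_224",              // becomes /sdcard/sunny/data/mobilenet_v1_1.0_224.tflite
      "file:///android_asset/labels.txt",  // label list bundled in the APK
      224,                                 // inputSize (width == height)
      false);                              // float, not quantized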
  @Override
  public List<Recognition> processImage(
      final AssetManager assetManager,
      Classifier.Recognition.inputFormat imageFormat,
      int[] intValues) {
    // Convert ARGB pixel values into the model's input layout.
    Trace.beginSection("preprocessBitmap");
    imgData.rewind();
    for (int i = 0; i < inputSize; ++i) {
      for (int j = 0; j < inputSize; ++j) {
        int pixelValue = intValues[i * inputSize + j];
        if (isModelQuantized) {
          // Quantized model: raw 8-bit channel values.
          imgData.put((byte) ((pixelValue >> 16) & 0xFF));
          imgData.put((byte) ((pixelValue >> 8) & 0xFF));
          imgData.put((byte) (pixelValue & 0xFF));
        } else {
          // Float model: normalize each channel to roughly [-1, 1].
          imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        }
      }
    }
    Trace.endSection(); // preprocessBitmap

    // Allocate the output array the interpreter copies results into.
    Trace.beginSection("feed");
    outputClasses = new float[1][NUM_DETECTIONS];
    Trace.endSection();

    // Run the inference call; this is where the GPU delegate error occurs.
    Trace.beginSection("run");
    tfLite.run(imgData, outputClasses);
    Trace.endSection();

    return recognitions;
  }
  @Override
  public void enableStatLogging(final boolean logStats) {}

  @Override
  public String getStatString() {
    return "";
  }

  @Override
  public void close() {
    tfLite.close();
    tfLite = null;
    recognitions.clear();
    recognitions = null;
    // Release the GPU delegate after the interpreter that uses it.
    gpuDelegate.close();
    gpuDelegate = null;
  }
}