I have a basic Android TensorFlowInference example that runs fine in a single thread.
import android.content.Context;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class InferenceExample {
    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "intput_node0";
    private static final String OUTPUT_NODE = "output_node0";
    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    // Note: static, so every InferenceExample shares this one interface.
    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {
        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];
        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);
        return res;
    }
}
When run in a ThreadPool as in the example below, it crashes with various exceptions, including java.lang.ArrayIndexOutOfBoundsException and java.lang.NullPointerException, so I guess it is not thread safe.
final InferenceExample inference = new InferenceExample(context);
ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);
Collection<Future<?>> futures = new LinkedList<Future<?>>();
for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            inference.run(randomData);
        }
    });
    futures.add(result);
}
// Wait for every task and log any failure.
for (Future<?> future : futures) {
    try {
        future.get();
    } catch (ExecutionException | InterruptedException e) {
        Log.e("TF", e.getMessage());
    }
}
Is it possible to use TensorFlowInferenceInterface from multiple threads like this?
Answer 0 (score: 1)
To make InferenceExample thread safe, I made the TensorFlowInferenceInterface field non-static and made the run method synchronized:
// Instance field: each InferenceExample now owns its own interface.
private TensorFlowInferenceInterface inferenceInterface;

public InferenceExample(final Context context) {
    inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
}

public synchronized float[] run(float[] data) { ... }
Then I round-robin the tasks across a list of numThreads InferenceExample instances.
for (int i = 1; i <= 100; i++) {
    final int id = i % numThreads;
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            list.get(id).run(data);
        }
    });
    futures.add(result);
}
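For anyone who wants to try the round-robin idea without a device, here is a minimal, self-contained sketch. It is not the real TensorFlow code: the Worker class is a hypothetical stand-in for the synchronized InferenceExample, and runAll and the class name are made up for illustration. It only shows the pattern of one instance per slot, a synchronized run method, and task i going to instance i % numThreads.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

public class RoundRobinInference {

    /** Hypothetical stand-in for InferenceExample: holds per-instance, non-shared state. */
    static class Worker {
        private final float[] scratch = new float[4]; // instance field, not static
        final AtomicInteger calls = new AtomicInteger();

        // synchronized: at most one thread touches this instance's state at a time.
        synchronized float[] run(float[] data) {
            for (int i = 0; i < scratch.length; i++) {
                scratch[i] = data[i] * 2f; // placeholder for feed/run/fetch
            }
            calls.incrementAndGet();
            return scratch.clone();
        }
    }

    /** Submits numTasks jobs, assigning task i to worker i % numThreads. */
    static List<Worker> runAll(int numThreads, int numTasks, final float[] data) throws Exception {
        final List<Worker> workers = new ArrayList<>();
        for (int i = 0; i < numThreads; i++) {
            workers.add(new Worker());
        }
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        List<Future<float[]>> futures = new ArrayList<>();
        try {
            for (int i = 0; i < numTasks; i++) {
                final Worker worker = workers.get(i % numThreads);
                Callable<float[]> task = () -> worker.run(data);
                futures.add(executor.submit(task));
            }
            for (Future<float[]> future : futures) {
                future.get(); // rethrows any exception thrown inside a task
            }
        } finally {
            executor.shutdown();
        }
        return workers;
    }

    public static void main(String[] args) throws Exception {
        List<Worker> workers = runAll(4, 100, new float[] {1f, 2f, 3f, 4f});
        for (Worker w : workers) {
            System.out.println(w.calls.get()); // prints 25 for each of the 4 workers
        }
    }
}
```

One limitation of fixed round-robin assignment is that a task can block on a busy instance while another instance sits idle, which may be part of why scaling flattens out.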
The round-robin approach does improve performance. However, on an 8-core device it peaks at a numThreads of 2 and shows only about 50% CPU usage in Android Studio Monitor.
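To find where throughput peaks on a particular device, one option is to time the same batch of work at several pool sizes. The sketch below is only an illustration under assumptions: fakeInference is a hypothetical CPU-bound placeholder rather than a real model call, and timeBatch and sweep are names invented here. The numbers it prints depend on the machine, and the first measurement is likely to include JIT warm-up.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ThreadSweep {

    /** Hypothetical CPU-bound stand-in for a single inference call. */
    static double fakeInference(int n) {
        double acc = 0;
        for (int i = 1; i <= n; i++) {
            acc += Math.sqrt(i);
        }
        return acc;
    }

    /** Runs numTasks jobs on a pool of numThreads threads; returns elapsed nanoseconds. */
    static long timeBatch(int numThreads, int numTasks, final int workPerTask) throws Exception {
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        List<Future<Double>> futures = new ArrayList<>();
        long start = System.nanoTime();
        try {
            for (int i = 0; i < numTasks; i++) {
                Callable<Double> task = () -> fakeInference(workPerTask);
                futures.add(executor.submit(task));
            }
            for (Future<Double> future : futures) {
                future.get(); // wait for completion
            }
        } finally {
            executor.shutdown();
        }
        return System.nanoTime() - start;
    }

    /** Times the same batch at pool sizes 1..maxThreads, in increasing order. */
    static Map<Integer, Long> sweep(int maxThreads, int numTasks, int workPerTask) throws Exception {
        Map<Integer, Long> elapsed = new LinkedHashMap<>();
        for (int n = 1; n <= maxThreads; n++) {
            elapsed.put(n, timeBatch(n, numTasks, workPerTask));
        }
        return elapsed;
    }

    public static void main(String[] args) throws Exception {
        int cores = Runtime.getRuntime().availableProcessors();
        for (Map.Entry<Integer, Long> e : sweep(cores, 100, 200_000).entrySet()) {
            System.out.println(e.getKey() + " threads: " + e.getValue() / 1_000_000 + " ms");
        }
    }
}
```

If the real workload stops speeding up at a small thread count while a purely CPU-bound placeholder like this keeps scaling, that would suggest the bottleneck is inside the inference call rather than in the thread pool setup.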
Answer 1 (score: 0)