我已经成功训练了一个数字分类器。现在,我想在 Android 中使用它。我从未使用过 TensorFlow,因此我跟着一些教程做,已经到了需要在 Android 应用中使用我生成的 .pb 文件这一步。我正在尝试加载它,但它需要 inputName 和 outputName,我不知道应该填什么。根据 Python 脚本,我认为 outputName 应该是 final_result,但其余的我不清楚。这就是我在Android
中所拥有的 mClassifiers.add(
TensorFlowClassifier.create(
context.getAssets(),
"?????", // name: stored on the classifier and returned by name() <- what goes here ?
"clasifier.pb", // modelPath: model file inside the assets
"labels.txt", // labelFile: one label per line
100, // inputSize
"????", // inputName: graph node the pixels are fed to <- what goes here ?
"???", // outputName: graph node the scores are fetched from <- what goes here ?
true) // feedKeepProb: also feed "keep_prob" before running
);
import android.content.res.AssetManager;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * {@link Classifier} backed by a TensorFlow graph loaded through
 * {@link TensorFlowInferenceInterface}.
 *
 * <p>Instances are built with {@link #create}; {@link #recognize} feeds the
 * pixels to the input node, runs the graph and returns the best-scoring label.
 */
public class TensorFlowClassifier implements Classifier {
    /** A score must be strictly greater than this to be reported. */
    private static final float THRESHOLD = 0.1f;

    private TensorFlowInferenceInterface tfHelper;
    private String name;
    private String inputName;
    private String outputName;
    private int inputSize;
    private boolean feedKeepProb;
    private List<String> labels;
    /** Reused buffer the output node is fetched into; one slot per label. */
    private float[] output;
    private String[] outputNames;

    /**
     * Reads the label file from the assets, one label per line, in file order.
     *
     * @param am       asset manager used to open the file
     * @param fileName asset path of the label file
     * @return the labels, in the order they appear in the file
     * @throws IOException if the asset cannot be opened or read
     */
    private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
        List<String> labels = new ArrayList<>();
        // try-with-resources closes the reader on every path; read failures are
        // propagated instead of being swallowed, so a broken label file is
        // reported here rather than as an index error during recognition.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName)))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }

    /**
     * Builds a classifier from a model file and a label file stored in the assets.
     *
     * @param assetManager asset manager used to load the model and the labels
     * @param name         name of this classifier, returned by {@link #name()}
     * @param modelPath    asset path of the model file
     * @param labelFile    asset path of the label file (one label per line)
     * @param inputSize    input size, stored on the instance
     * @param inputName    name of the graph node the pixels are fed to
     * @param outputName   name of the graph node the scores are fetched from
     * @param feedKeepProb whether to also feed {@code "keep_prob"} before running
     * @return the configured classifier
     * @throws IOException if the label file cannot be read or contains no labels
     */
    public static TensorFlowClassifier create(AssetManager assetManager,
                                              String name,
                                              String modelPath,
                                              String labelFile,
                                              int inputSize,
                                              String inputName,
                                              String outputName,
                                              boolean feedKeepProb) throws IOException {
        TensorFlowClassifier c = new TensorFlowClassifier();

        c.name = name;
        c.inputName = inputName;
        c.outputName = outputName;
        c.inputSize = inputSize;
        c.feedKeepProb = feedKeepProb;

        c.labels = readLabels(assetManager, labelFile);
        if (c.labels.isEmpty()) {
            throw new IOException("Label file contains no labels: " + labelFile);
        }

        c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);

        // Pre-allocate the output buffer with one slot per label, so the buffer
        // and the label list always agree (previously hard-coded to 10 classes).
        c.outputNames = new String[] { outputName };
        c.output = new float[c.labels.size()];
        return c;
    }

    /** Returns the name given to {@link #create}. */
    @Override
    public String name() {
        return name;
    }

    /**
     * Runs the model on the given pixels and returns the best classification.
     *
     * @param pixels pixel values, fed with dimensions {@code {1, width, height, 1}}
     * @param width  second dimension of the fed tensor
     * @param height third dimension of the fed tensor
     * @return the label with the highest score above {@link #THRESHOLD}, or an
     *         untouched {@code Classification} if no label qualifies
     */
    @Override
    public Classification recognize(final float[] pixels, final int width, final int height) {
        tfHelper.feed(inputName, pixels, 1, width, height, 1);

        // NOTE(review): presumably a dropout keep-probability placeholder; 1 is
        // fed so nothing is dropped at inference time -- confirm against the graph.
        if (feedKeepProb) {
            tfHelper.feed("keep_prob", new float[] { 1 });
        }

        tfHelper.run(outputNames);
        tfHelper.fetch(outputName, output);

        // Keep the highest score that clears the threshold.
        Classification ans = new Classification();
        for (int i = 0; i < output.length; ++i) {
            String label = labels.get(i);
            // The label "0" is never reported (behavior kept from the original;
            // NOTE(review): the reason is not visible in this file).
            if (!label.equals("0") && output[i] > THRESHOLD && output[i] > ans.getConf()) {
                ans.update(output[i], label);
            }
        }
        return ans;
    }
}
由于无法在此处贴出 Python 脚本,可以在这里找到它: https://github.com/MicrocontrollersAndMore/TensorFlow_Tut_2_Classification_Walk-through/blob/master/retrain.py
答案 0 :(得分:0)
// Load the model file from the app's assets.
TensorFlowInferenceInterface tensorflow = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
// Walk every operation in the loaded graph and print its name on its own line,
// so candidate input/output node names can be read off the output.
Iterator<Operation> operationIterator = tensorflow.graph().operations();
while (operationIterator.hasNext()) {
    Operation operation = operationIterator.next();
    // println (not print): without a newline all names run together.
    System.out.println(operation.name());
}
加载模型文件后,可以尝试运行这段代码来查看图中各个操作(层)的名称。希望对您有所帮助!