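I have a TensorFlow model that I trained in Python by following this article. After training, I generated a frozen graph. Now I need to use this graph to build a recognition application in Java. For that, I have been looking at the following example, but I don't understand how to collect my output. I know I need to provide 3 inputs to the graph.
From the example given in the official tutorial, I have read the Python code that does this: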
def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
    """Runs the audio data through the graph and prints predictions."""
    with tf.Session() as sess:
        # Feed the audio data as input to the graph.
        #   predictions  will contain a two-dimensional array, where one
        #   dimension represents the input image count, and the other has
        #   predictions per class
        softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
        predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

        # Sort to show labels in order of confidence
        top_k = predictions.argsort()[-num_top_predictions:][::-1]
        for node_id in top_k:
            human_string = labels[node_id]
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

        return 0
Can someone help me understand the TensorFlow Java API?
Answer 0 (score: 1)
A literal translation of the Python code listed above would be something like this:
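import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
    // try-with-resources closes both tensors and frees their native memory.
    try (Tensor<String> wavDataTensor = Tensors.create(wavData);
         Tensor<Float> predictionsTensor = sess.runner()
             .feed(inputLayerName, wavDataTensor)
             .fetch(outputLayerName)
             .run()
             .get(0)
             .expect(Float.class)) {
        // shape() returns a long[]; the output is expected to be [rows, classes].
        long[] shape = predictionsTensor.shape();
        float[][] predictions = new float[(int) shape[0]][(int) shape[1]];
        predictionsTensor.copyTo(predictions);
        return predictions;
    }
}
The returned predictions array will hold the "confidence" value of each prediction. You will have to run your own logic to compute the "top K" on it, similar to how the Python code uses numpy (.argsort()) on what sess.run() returned.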
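The top-K step has no direct counterpart in the Java API, so it has to be written by hand. Below is a minimal sketch of one way to do it in plain Java, assuming you pass in a single row of the predictions array and a label array in the same order as the model's classes. The class name TopK and the methods topK and describe are hypothetical names chosen for illustration; they are not part of TensorFlow.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

// Hypothetical helper, not part of the TensorFlow Java API.
final class TopK {
    /** Returns the indices of the k highest scores, best first. */
    static int[] topK(float[] scores, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            indices.add(i);
        }
        // Sort indices by score in descending order.
        indices.sort((a, b) -> Float.compare(scores[b], scores[a]));

        int n = Math.max(0, Math.min(k, scores.length));
        int[] top = new int[n];
        for (int i = 0; i < n; i++) {
            top[i] = indices.get(i);
        }
        return top;
    }

    /** Formats the top k entries the same way the Python example prints them. */
    static List<String> describe(float[] scores, String[] labels, int k) {
        List<String> lines = new ArrayList<>();
        for (int id : topK(scores, k)) {
            lines.add(String.format(Locale.ROOT, "%s (score = %.5f)", labels[id], scores[id]));
        }
        return lines;
    }
}
```

With the method from the answer, a call would look roughly like TopK.describe(getPredictions(sess, wavData, inputName, outputName)[0], labels, 3), where labels is whatever label list you trained with.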
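From a cursory reading of the tutorial page and code, it seems that predictions will have 1 row and 12 columns (one for each hotword). I inferred this from the tutorial's Python code.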
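Since the 1 x 12 shape is only a reading of the tutorial and not something the graph guarantees, it may be worth checking the shape before indexing into the result. The following is a small sketch under that assumption; the class PredictionRows, the method singleRow, and the expectedClasses parameter are hypothetical names used only for illustration.

```java
// Hypothetical helper, not part of the TensorFlow Java API.
final class PredictionRows {
    /**
     * Returns the only row of a [1, expectedClasses] predictions array,
     * or throws if the array has a different shape.
     */
    static float[] singleRow(float[][] predictions, int expectedClasses) {
        if (predictions.length != 1) {
            throw new IllegalArgumentException(
                "expected 1 row but got " + predictions.length);
        }
        if (predictions[0].length != expectedClasses) {
            throw new IllegalArgumentException(
                "expected " + expectedClasses + " classes but got " + predictions[0].length);
        }
        return predictions[0];
    }
}
```

If your model was trained with a different number of words, pass that count instead of 12.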
Hope that helps.