TensorFlow: How to use a speech recognition model trained in Python from Java

Time: 2018-08-20 11:52:05

Tags: java python tensorflow

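I have a TensorFlow model that I trained in Python by following this article, and after training I generated the frozen graph. Now I need to use this graph to run recognition in a Java-based application. For this I have been looking at the following example, but I don't understand how to collect my output. I know I need to provide 3 inputs to the graph.

From the example given in the official tutorial, I have read through the Python-based code.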

def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions  will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0

Can someone help me understand the TensorFlow Java API?

1 Answer:

Answer 0: (score: 1)

A literal translation of the Python code listed above would look something like this:

import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public static float[][] getPredictions(Session sess, byte[] wavData,
    String inputLayerName, String outputLayerName) {
  // Wrap the raw WAV bytes in a scalar string tensor; try-with-resources
  // closes both tensors once the results have been copied out.
  try (Tensor<String> wavDataTensor = Tensors.create(wavData);
       Tensor<Float> predictionsTensor = sess.runner()
           .feed(inputLayerName, wavDataTensor)
           .fetch(outputLayerName)
           .run()
           .get(0)
           .expect(Float.class)) {
    // shape() returns a long[] with one entry per dimension.
    long[] shape = predictionsTensor.shape();
    float[][] predictions = new float[(int) shape[0]][(int) shape[1]];
    predictionsTensor.copyTo(predictions);
    return predictions;
  }
}

The returned predictions array holds the "confidence" value of each prediction. You will have to compute the "top K" over it yourself, much as the Python code uses numpy (.argsort()) on what sess.run() returned.
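The "top K" step can be done in plain Java without any TensorFlow calls. The following is a minimal sketch, assuming the scores for one clip are a single row of the predictions array; the TopK class, the topK method, and the labels and scores in main are hypothetical names and values used only for illustration.

```java
import java.util.Arrays;

class TopK {
    /** Returns the indices of the k largest scores, highest score first. */
    static int[] topK(float[] scores, int k) {
        // Box the indices so they can be sorted with a custom comparator.
        Integer[] indices = new Integer[scores.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }
        // Sort indices by descending score; the sort is stable, so ties keep index order.
        Arrays.sort(indices, (a, b) -> Float.compare(scores[b], scores[a]));

        // Keep at most k indices (fewer if there are not enough scores).
        int n = Math.max(0, Math.min(k, indices.length));
        int[] top = new int[n];
        for (int i = 0; i < n; i++) {
            top[i] = indices[i];
        }
        return top;
    }

    public static void main(String[] args) {
        // Hypothetical labels and scores; real values come from the labels
        // file and from the first row of the predictions array.
        String[] labels = {"yes", "no", "up"};
        float[][] predictions = {{0.1f, 0.7f, 0.2f}};
        for (int id : topK(predictions[0], 2)) {
            System.out.printf("%s (score = %.5f)%n", labels[id], predictions[0][id]);
        }
    }
}
```

This mirrors the Python loop in the question: argsort followed by slicing and reversing becomes a descending sort of the indices followed by taking the first k.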

From a cursory reading of the tutorial page and code, it seems that predictions will have 1 row and 12 columns (one for each hotword). I inferred this from the Python code.

Hope that helps.