在java中加载DNNClassifier

时间:2018-02-08 20:01:24

标签: java python tensorflow

我在python中学习了tensorflow中的DNNClassifier。我有32个浮点作为输入,我有4个输出类。这是程序:

  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=GESTURE_TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=GESTURE_TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.feature_column.numeric_column("x", shape=[32])]

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                          hidden_units=[10, 20, 10],
                                          n_classes=4,
                                          model_dir="./model/")




# Define the training inputs
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": np.array(training_set.data)},
      y=np.array(training_set.target),
      num_epochs=None,
      shuffle=True)

  # Train model.

  classifier.train(input_fn=train_input_fn, steps=400)

这是虹膜修改的例子,它生成.pb文件和这个文件:

checkpoint
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index

预测,使用该模型我使用该函数加载模型:   predictor = tf.contrib.predictor.from_saved_model(exported_pa​​th) 使用“exported_pa​​th”指向pb文件的路径。

我的问题是如何在java中加载我的模型。在python中,我尝试使用它来加载我的模型:

with tf.gfile.GFile(filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

但我得到了这个错误:

graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message

问题是我在java中找不到tf.contrib.predictor.from_saved_model的等价物。

1 个答案:

答案 0 :(得分:1)

您想要执行"Using SavedModel with Estimators"中建议的操作,这意味着您可以使用以下内容在Python程序中导出:

# Input to the classifier is a batch of 32-element vectors
inputs = {"x" : tf.placeholder(tf.float32, shape=[None, 32])}
classifier.export_savedmodel("./saved_model", tf.estimator.export.build_raw_serving_input_receiver_fn(inputs))

然后使用SavedModelBundle.load()

加载并执行Java

例如,这是用于训练模型然后以SavedModel格式导出模型的Python代码:

import tensorflow as tf
import numpy as np

feature_columns = [tf.feature_column.numeric_column("x", shape=[32])]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[10, 20, 10],
                                        n_classes=4,
                                        model_dir="./model/")

# Random inputs and outputs here, probably want them from the file
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.random.rand(10, 32)},
    y=np.random.randint(4, size=10),
    num_epochs=None,
    shuffle=True) 
classifier.train(input_fn=train_input_fn, steps=400)

inputs = {"x" : tf.placeholder(tf.float32, shape=[None, 32])}
classifier.export_savedmodel("./saved_model", tf.estimator.export.build_raw_serving_input_receiver_fn(inputs))

这是相应的Java代码,用于加载训练模型并对其执行预测。

try (SavedModelBundle model = SavedModelBundle.load("./saved_model/1518198088", "serve")) {
  // A batch of inputs. In real life of course you'd set each row to the actual input you're
  // interested in.
  final int BATCH_SIZE = 1;
  float[][] in = new float[BATCH_SIZE][32];
  try (Tensor<Float> tInput = Tensors.create(in);
      Tensor<Float> tProbs =
          model.session().runner()
              .feed("Placeholder", tInput)
              .fetch("dnn/head/predictions/probabilities")
              .run().get(0).expect(Float.class)) {
    float[][] probabilities = tProbs.copyTo(new float[BATCH_SIZE][4]);
    System.out.print("Predicted class probabilities: ");
    for (int i = 0; i < probabilities.length; ++i) {
      System.out.println(String.format("-- Input #%d", i));
      for (int j = 0; j < probabilities[i].length; ++j) {
        System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
      }
    }
  }
}

您可能还会发现解释TensorFlow模型格式有用的slides(链接到tensorflow/models存储库。提供给feedfetch的张量的名称可以可以从以下地址获得:

  1. 使用saved_model_cli show --dir ./saved_model/1518198088 --all

  2. 的命令行
  3. 使用Java API解析SavedModelBundle中的模型签名信息。请参阅tensorflow/models/samples/languages/java和/或this code sample

  4. 上的幻灯片

    在Python中创建model_dir对象时提供的Estimator在几个文件中写出模型 - 以人类可读格式写出的协议缓冲区形式的计算图形( graph.pbtxt)和一些包含训练权重的二进制文件。您可以直接从Java中读取它们,但这意味着您必须管理解析图形,然后通过运行“从检查点恢复”操作将权重初始化为训练值。

    SavedModel格式将所有这些打包在一起,因此在Java中使用SavedModelBundle.load()会为您完成所有这些。

    希望有所帮助。