我在python中学习了tensorflow中的DNNClassifier。我有32个浮点作为输入,我有4个输出类。这是程序:
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=GESTURE_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=GESTURE_TEST,
target_dtype=np.int,
features_dtype=np.float32)
# Specify that all features have real-value data
feature_columns = [tf.feature_column.numeric_column("x", shape=[32])]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=4,
model_dir="./model/")
# Define the training inputs
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(training_set.data)},
y=np.array(training_set.target),
num_epochs=None,
shuffle=True)
# Train model.
classifier.train(input_fn=train_input_fn, steps=400)
这是虹膜修改的例子,它生成.pb文件和这个文件:
checkpoint
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
预测,使用该模型我使用该函数加载模型: predictor = tf.contrib.predictor.from_saved_model(exported_path) 使用“exported_path”指向pb文件的路径。
我的问题是如何在java中加载我的模型。在python中,我尝试使用它来加载我的模型:
with tf.gfile.GFile(filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
但我得到了这个错误:
graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message
问题是我在java中找不到tf.contrib.predictor.from_saved_model的等价物。
答案 0 :(得分:1)
您想要执行"Using SavedModel with Estimators"中建议的操作,这意味着您可以使用以下内容在Python程序中导出:
# Input to the classifier is a batch of 32-element vectors
inputs = {"x" : tf.placeholder(tf.float32, shape=[None, 32])}
classifier.export_savedmodel("./saved_model", tf.estimator.export.build_raw_serving_input_receiver_fn(inputs))
加载并执行Java
例如,这是用于训练模型然后以SavedModel格式导出模型的Python代码:
import tensorflow as tf
import numpy as np
feature_columns = [tf.feature_column.numeric_column("x", shape=[32])]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=4,
model_dir="./model/")
# Random inputs and outputs here, probably want them from the file
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.random.rand(10, 32)},
y=np.random.randint(4, size=10),
num_epochs=None,
shuffle=True)
classifier.train(input_fn=train_input_fn, steps=400)
inputs = {"x" : tf.placeholder(tf.float32, shape=[None, 32])}
classifier.export_savedmodel("./saved_model", tf.estimator.export.build_raw_serving_input_receiver_fn(inputs))
这是相应的Java代码,用于加载训练模型并对其执行预测。
try (SavedModelBundle model = SavedModelBundle.load("./saved_model/1518198088", "serve")) {
// A batch of inputs. In real life of course you'd set each row to the actual input you're
// interested in.
final int BATCH_SIZE = 1;
float[][] in = new float[BATCH_SIZE][32];
try (Tensor<Float> tInput = Tensors.create(in);
Tensor<Float> tProbs =
model.session().runner()
.feed("Placeholder", tInput)
.fetch("dnn/head/predictions/probabilities")
.run().get(0).expect(Float.class)) {
float[][] probabilities = tProbs.copyTo(new float[BATCH_SIZE][4]);
System.out.print("Predicted class probabilities: ");
for (int i = 0; i < probabilities.length; ++i) {
System.out.println(String.format("-- Input #%d", i));
for (int j = 0; j < probabilities[i].length; ++j) {
System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
}
}
}
}
您可能还会发现解释TensorFlow模型格式有用的slides(链接到tensorflow/models存储库。提供给feed
和fetch
的张量的名称可以可以从以下地址获得:
使用saved_model_cli show --dir ./saved_model/1518198088 --all
或
使用Java API解析SavedModelBundle
中的模型签名信息。请参阅tensorflow/models/samples/languages/java和/或this code sample
在Python中创建model_dir
对象时提供的Estimator
在几个文件中写出模型 - 以人类可读格式写出的协议缓冲区形式的计算图形( graph.pbtxt
)和一些包含训练权重的二进制文件。您可以直接从Java中读取它们,但这意味着您必须管理解析图形,然后通过运行“从检查点恢复”操作将权重初始化为训练值。
SavedModel格式将所有这些打包在一起,因此在Java中使用SavedModelBundle.load()
会为您完成所有这些。
希望有所帮助。