我训练了一个模型来使用 MNIST 数据集识别数字。该模型已使用 TensorFlow 和 Keras 在 Python 中进行训练,输出保存到我命名为“sample_mnist.h5”的 HDF5 文件中。
我想将经过训练的模型从 HDF5 文件加载到 Rust 中以进行预测。
在 Python 中,我可以从 HDF5 生成模型并使用代码进行预测:
model = keras.models.load_model("./sample_mnist.h5")
model.precict(test_input) # assumes test_input is the correct input type for the model
这个 Python 片段的 Rust 等价物是什么?
答案 0 :(得分:0)
首先,您需要将模型保存为 .pb 格式,而不是 .hdf5,以便将其移植到 Rust,因为这种格式可以节省 >关于在 Python 之外重建模型所需的模型执行图的所有内容。 TensorFlow Rust 存储库上有来自用户 pull request 的公开 justnoxx,展示了如何为简单模型执行此操作。要点是在 Python 中给出了一些模型......
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
classifier = Sequential()
classifier.add(Dense(5, activation='relu', name="test_in", input_dim=5)) # Named input
classifier.add(Dense(5, activation='relu'))
classifier.add(Dense(1, activation='sigmoid', name="test_out")) # Named output
classifier.compile(optimizer ='adam', loss='binary_crossentropy', metrics=['accuracy'])
classifier.fit([[0.1, 0.2, 0.3, 0.4, 0.5]], [[1]], batch_size=1, epochs=1);
classifier.save('examples/keras_single_input_saved_model', save_format='tf')
以及我们命名的输入“test_in”和输出“test_out”以及它们的预期大小,我们可以在 Rust 中应用保存的模型......
use tensorflow::{Graph, SavedModelBundle, SessionOptions, SessionRunArgs, Tensor};
fn main() {
// In this file test_in_input is being used while in the python script,
// that generates the saved model from Keras model it has a name "test_in".
// For multiple inputs _input is not being appended to signature input parameter name.
let signature_input_parameter_name = "test_in_input";
let signature_output_parameter_name = "test_out";
// Initialize save_dir, input tensor, and an empty graph
let save_dir =
"examples/keras_single_input_saved_model";
let tensor: Tensor<f32> = Tensor::new(&[1, 5])
.with_values(&[0.1, 0.2, 0.3, 0.4, 0.5])
.expect("Can't create tensor");
let mut graph = Graph::new();
// Load saved model bundle (session state + meta_graph data)
let bundle =
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, save_dir)
.expect("Can't load saved model");
// Get the session from the loaded model bundle
let session = &bundle.session;
// Get signature metadata from the model bundle
let signature = bundle
.meta_graph_def()
.get_signature("serving_default")
.unwrap();
// Get input/output info
let input_info = signature.get_input(signature_input_parameter_name).unwrap();
let output_info = signature
.get_output(signature_output_parameter_name)
.unwrap();
// Get input/output ops from graph
let input_op = graph
.operation_by_name_required(&input_info.name().name)
.unwrap();
let output_op = graph
.operation_by_name_required(&output_info.name().name)
.unwrap();
// Manages inputs and outputs for the execution of the graph
let mut args = SessionRunArgs::new();
args.add_feed(&input_op, 0, &tensor); // Add any inputs
let out = args.request_fetch(&output_op, 0); // Request outputs
// Run model
session.run(&mut args) // Pass to session to run
.expect("Error occurred during calculations");
// Fetch outputs after graph execution
let out_res: f32 = args.fetch(out).unwrap()[0];
println!("Results: {:?}", out_res);
}