我正在寻找一些使用 Java 从 TensorFlow 模型中提取权重和偏置的示例代码。有人知道这是否可行吗?还是使用 Python API 会更好?
谢谢 -比尔
答案 0 :(得分:0)
以下内容对我有用:
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Session.Runner;
import org.tensorflow.Tensor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
private static String graphPath = "models/export/exporter/1390282892";
public class PrintWeights {
public static void main(String[] args) throws Exception {
try (SavedModelBundle b = SavedModelBundle.load(graphPath, "serve")) {
Graph graph = b.graph();
GraphDef graphDef = GraphDef.parseFrom(graph.toGraphDef());
try (Session session = b.session()) {
Runner runner = session.runner();
Tensor<Float> tensor = (Tensor<Float>)runner.fetch("dnn/hiddenlayer_0/kernel").run().get(0);
float[][] weights = new float[(int)tensor.shape()[0]][(int)tensor.shape()[1]];
tensor.copyTo(weights);
System.out.println(Arrays.deepToString(weights).replace("], ", "]\n"));
}
}
}
}