使用Java从DNNRegressor生成的模型中读取权重

时间:2018-08-24 02:52:32

标签: tensorflow machine-learning

我正在寻找一些使用Java从Tensorflow模型中提取权重和偏差的示例代码。有人知道这是否可能吗?使用python API会更好吗?

谢谢 -比尔

1 个答案:

答案 0 :(得分:0)

以下内容对我有用:

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Session.Runner;
import org.tensorflow.Tensor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;


private static String graphPath = "models/export/exporter/1390282892";

public class PrintWeights {

    public static void main(String[] args) throws Exception {

        try (SavedModelBundle b = SavedModelBundle.load(graphPath, "serve")) {

            Graph graph = b.graph();

            GraphDef graphDef = GraphDef.parseFrom(graph.toGraphDef());

            try (Session session = b.session()) {

                Runner runner = session.runner();

                Tensor<Float> tensor = (Tensor<Float>)runner.fetch("dnn/hiddenlayer_0/kernel").run().get(0);

                float[][] weights = new float[(int)tensor.shape()[0]][(int)tensor.shape()[1]];

                tensor.copyTo(weights);

                System.out.println(Arrays.deepToString(weights).replace("], ", "]\n"));

            }
        }
    }

}