Ignite: how to save and reload a trained model

Time: 2019-01-31 06:12:02

Tags: java intellij-idea ignite

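Below is the code I use to train a model. Apart from the FileExporter class, how and where can I save the trained model and read it back later? Can it only be stored in a file, or can I also keep it in a cache and access it from there?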

IgniteCache<Integer, double[]> cache = ignite.getOrCreateCache("MLData_IRIS");

// Extract the features: sepal length, sepal width, petal length, petal width
IgniteBiFunction<Integer, double[], Vector> featureExtractor = new RangeExtractor(1, 5);
IgniteBiFunction<Integer, double[], Double> labelExtractor = new PointExtractor(0);

System.out.println(">>> Create new training dataset splitter object.");
TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
    .split(0.5, 0.5);

IgniteBiPredicate<Integer, double[]> testData = split.getTestFilter();
IgniteBiPredicate<Integer, double[]> trainData = split.getTrainFilter();

// Set up the trainer
KMeansTrainer trainer = new KMeansTrainer()
    .withDistance(new EuclideanDistance())  // other metrics: HammingDistance, ManhattanDistance
    .withAmountOfClusters(3) // number of clusters to create
    .withMaxIterations(100)
    .withEpsilon(1.0E-4D)
    .withSeed(1234L);

long t1 = System.currentTimeMillis();

KMeansModel mdl = trainer.fit(
    ignite,
    cache,
    trainData,
    featureExtractor,
    labelExtractor
);

long t2 = System.currentTimeMillis();
System.out.println("time taken to build the model : " + (t2 - t1) + " ms");

System.out.println(">>> --------------------------------------------");
System.out.println(">>> trained model: " + mdl.toString(true));

1 answer:

Answer 0 (score: 2)

At the moment, the only mechanism Ignite provides for this is FileExporter.
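Outside of FileExporter, persisting a model comes down to turning it into bytes and back. As a minimal sketch, assuming the trained model class implements `java.io.Serializable`, plain Java serialization can write it to a file and restore it. This is not the FileExporter API: `FileRoundTrip`, `save`, `load` and `ToyModel` are hypothetical names, and `ToyModel` only stands in for a real model such as `KMeansModel`.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;

class FileRoundTrip {
    // Hypothetical stand-in for a trained model (e.g. cluster centers); not an Ignite class.
    static class ToyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[][] centers;

        ToyModel(double[][] centers) {
            this.centers = centers;
        }
    }

    // Writes any Serializable object to the given file using standard Java serialization.
    static void save(Serializable obj, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(obj);
        }
    }

    // Reads the object back; the caller casts it to the expected type.
    static Object load(Path file) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Path file = Files.createTempFile("model", ".bin");
        ToyModel original = new ToyModel(new double[][] {{5.0, 3.4}, {6.8, 3.1}});
        save(original, file);
        ToyModel restored = (ToyModel) load(file);
        System.out.println(restored.centers[1][0]); // 6.8
        Files.delete(file);
    }
}
```

Java serialization ties the saved bytes to the class definition, so a model saved this way should be loaded with the same class version on the classpath.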

However, model storage has been implemented for version 2.8.

Example of saving a model:

// serializedMdl: the trained model serialized to a byte[]
ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
storage.mkdirs("/");
storage.putFile("/my_model", serializedMdl);

ModelDescriptor desc = new ModelDescriptor(
    "MyModel",
    "My Cool Model",
    new ModelSignature("", "", ""),
    new ModelStorageModelReader("/my_model"),
    new IgniteModelParser<>()
);
ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
descStorage.put("my_model", desc);
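The saving example stores the model in two places: the serialized bytes under a path (`/my_model`) and a descriptor under a name (`my_model`) whose reader points back at that path. The toy registry below mimics that split with two `HashMap`s so the lookup order is easy to follow. `ToyModelRegistry` and its methods are hypothetical and are not Ignite classes; the real storages are obtained from the Ignite instance rather than kept in local maps.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical in-memory illustration of the two-step save/load flow.
// It only mimics the roles of ModelStorage (path -> bytes) and
// ModelDescriptorStorage (name -> descriptor); it is not Ignite code.
class ToyModelRegistry {
    // Plays the role of ModelStorage: serialized model bytes keyed by path.
    private final Map<String, byte[]> files = new HashMap<>();
    // Plays the role of ModelDescriptorStorage: model name -> storage path.
    private final Map<String, String> descriptors = new HashMap<>();

    void putFile(String path, byte[] bytes) {
        files.put(path, bytes.clone()); // defensive copy
    }

    // Registers a name only if the bytes it refers to already exist.
    void register(String name, String path) {
        if (!files.containsKey(path))
            throw new IllegalArgumentException("No file at " + path);
        descriptors.put(name, path);
    }

    // Looks up the descriptor by name, then reads the bytes it points to.
    Optional<byte[]> load(String name) {
        String path = descriptors.get(name);
        if (path == null)
            return Optional.empty();
        return Optional.ofNullable(files.get(path)).map(b -> b.clone());
    }

    public static void main(String[] args) {
        ToyModelRegistry registry = new ToyModelRegistry();
        registry.putFile("/my_model", new byte[] {1, 2, 3});
        registry.register("my_model", "/my_model");
        System.out.println(registry.load("my_model").map(b -> b.length).orElse(-1)); // 3
        System.out.println(registry.load("unknown").isPresent()); // false
    }
}
```

Keeping the name separate from the path means the loading side only needs the model name; where the bytes live is recorded in the descriptor.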

Example of loading a model:

Ignite ignite = Ignition.ignite();

ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
ModelDescriptor desc = descStorage.get(mdl);

Model<byte[], byte[]> infMdl = new SingleModelBuilder().build(desc.getReader(), desc.getParser());

Vector input = VectorUtils.of(x);

try {
    return deserialize(infMdl.predict(serialize(input)));
}
catch (IOException | ClassNotFoundException e) {
    throw new RuntimeException(e);
}

Here x is a vector of doubles and mdl is the model name (for example "my_model", as registered in the saving example).
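The loading example calls `serialize(...)` and `deserialize(...)` without showing them. Judging by the `IOException | ClassNotFoundException` catch clause, they plausibly wrap Java object streams; the sketch below assumes exactly that. `SerializationHelpers` is a hypothetical name, and the original helpers may be implemented differently.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;

// Hypothetical helpers matching the serialize(...)/deserialize(...) calls
// in the loading example, assuming plain Java serialization.
final class SerializationHelpers {
    private SerializationHelpers() {}

    // Turns any Serializable object (e.g. an input vector) into bytes.
    static byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        return bytes.toByteArray();
    }

    // Reconstructs the object and casts it to the caller's expected type.
    @SuppressWarnings("unchecked")
    static <T> T deserialize(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (T) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        double[] x = {5.1, 3.5, 1.4, 0.2};
        double[] back = deserialize(serialize(x));
        System.out.println(Arrays.toString(back)); // [5.1, 3.5, 1.4, 0.2]
    }
}
```

The unchecked cast in `deserialize` means a type mismatch surfaces as a `ClassCastException` at the call site, not inside the helper.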

Note: this API will be available in version 2.8. However, if you build Ignite from the master branch, you can try it right now.