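Saving an H2O model directly from Java

Time: 2017-05-30 18:02:24

Tags: java maven h2o

I am trying to create a model and save the generated model directly from Java. The documentation specifies how to do this in R and Python, but not in Java. A similar question was asked before, but it received no real answer (other than a link to the H2O docs, which contain no code example).

For my current purposes, it would be enough to get some pointers on how to translate the following reference code into Java. I am mainly looking for guidance on the relevant JARs to import from the Maven repository.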

import h2o
h2o.init()
path = h2o.system_file("prostate.csv")
h2o_df = h2o.import_file(path)
h2o_df['CAPSULE'] = h2o_df['CAPSULE'].asfactor()
model = h2o.glm(y = "CAPSULE",
            x = ["AGE", "RACE", "PSA", "GLEASON"],
            training_frame = h2o_df,
            family = "binomial")
h2o.download_pojo(model)

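2 answers:

Answer 0 (score: 1)

I think I have found an answer to my own question. A self-contained code example follows. I would still appreciate answers from the community, though, since I don't know whether this is the best/idiomatic way to do it.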

package org.name.company;

import hex.glm.GLMModel;
import water.H2O;
import water.Key;
import water.api.StreamWriter;
import water.api.StreamingSchema;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import hex.glm.GLMModel.GLMParameters.Family;
import hex.glm.GLMModel.GLMParameters;
import hex.glm.GLM;
import water.util.JCodeGen;

import java.io.*;
import java.util.Map;

public class Launcher
{
    public static void initCloud(){
        String[] args = new String [] {"-name", "h2o_test_cloud"};
        H2O.main(args);
        H2O.waitForCloudSize(1, 10 * 1000);
    }

    public static void main( String[] args ) throws Exception {
        // Initialize the cloud
        initCloud();

        // Create a Frame object from CSV
        File f = new File("/path/to/data.csv");
        NFSFileVec nfs = NFSFileVec.make(f);
        Key frameKey = Key.make("frameKey");
        Frame fr = water.parser.ParseDataset.parse(frameKey, nfs._key);

        // Create a GLM and output coefficients
        Key modelKey = Key.make("modelKey");
        try {
            GLMParameters params = new GLMParameters();
            params._train = frameKey;
            params._response_column = fr.names()[1];
            params._intercept = true;
            params._lambda = new double[]{0};
            params._family = Family.gaussian;

            GLMModel model = new GLM(params).trainModel().get();
            Map<String, Double> coefs = model.coefficients();
            for(Map.Entry<String, Double> entry : coefs.entrySet()) {
                System.out.format("%s: %f\n", entry.getKey(), entry.getValue());
            }

            String filename = JCodeGen.toJavaId(model._key.toString()) + ".java";
            StreamingSchema ss = new StreamingSchema(model.new JavaModelStreamWriter(false), filename);
            StreamWriter sw = ss.getStreamWriter();
            // Close the stream once the POJO source has been written
            try (OutputStream os = new FileOutputStream("/base/path/" + filename)) {
                sw.writeTo(os);
            }

        } finally {
            if (fr != null) {
                fr.remove();
            }
        }
    }
}
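One practical detail when writing the generated source to disk: `FileOutputStream` does not create missing parent directories, so the write fails if the base path does not exist yet. A minimal JDK-only sketch of preparing the target path first is shown here; the class name `PojoOutput` and the method `prepare` are hypothetical helpers for illustration and are not part of H2O.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical helper (not part of H2O): makes sure the output directory
// exists and returns the full path for the generated POJO source file.
public class PojoOutput {

    public static Path prepare(String baseDir, String filename) {
        Path dir = Paths.get(baseDir);
        try {
            // Creates the directory and any missing parents; no-op if present.
            Files.createDirectories(dir);
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot create output directory " + dir, e);
        }
        return dir.resolve(filename);
    }

    public static void main(String[] args) {
        // Example paths only; substitute the real base path and file name.
        Path target = prepare(System.getProperty("java.io.tmpdir") + "/h2o_pojo", "modelKey.java");
        System.out.println("POJO source will be written to " + target);
    }
}
```

The returned path can then be passed to `new FileOutputStream(target.toFile())` in place of the concatenated string.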

Answer 1 (score: 1)

Would something like this do the trick?

public void saveModel(URI uri, Keyed<Frame> model)
{
    Persist p = H2O.getPM().getPersistForURI(uri);
    OutputStream os = p.create(uri.toString(), true);
    model.writeAll(new AutoBuffer(os, true)).close();
}

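Make sure the URI is well-formed, otherwise H2O will fail with an NPE. As for Maven, you should be able to get away with just h2o-core for this snippet; note that the hex.glm classes used in the other answer live in the separate h2o-algos artifact.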

    <dependency>
        <groupId>ai.h2o</groupId>
        <artifactId>h2o-core</artifactId>
        <version>3.14.0.2</version>
    </dependency>
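On the URI advice in this answer: one way to fail early with a readable message, rather than hitting an NPE inside H2O, is to check the URI before handing it to `saveModel`. The sketch below assumes that "well-formed" means the URI carries an explicit scheme (for example `file://` or `hdfs://`); the answer does not spell this out, so treat that as an assumption. `ModelUriCheck` and `requireScheme` are hypothetical names, and only the JDK is used.

```java
import java.net.URI;

// Hypothetical pre-check (not part of H2O): rejects URIs without a scheme.
// Assumption: a missing scheme is what triggers the NPE mentioned above.
public class ModelUriCheck {

    /** Returns the URI unchanged if it has a scheme; otherwise throws. */
    public static URI requireScheme(URI uri) {
        if (uri == null || uri.getScheme() == null) {
            throw new IllegalArgumentException(
                "URI must include a scheme such as file:// or hdfs://, got: " + uri);
        }
        return uri;
    }

    public static void main(String[] args) {
        // Example location only; substitute the real target URI.
        URI target = requireScheme(URI.create("file:///tmp/model.bin"));
        System.out.println("Would save model to " + target);
    }
}
```

A call such as `saveModel(ModelUriCheck.requireScheme(uri), model)` would then surface a bad path like `/tmp/model.bin` as an `IllegalArgumentException` before any H2O code runs.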