I downloaded a POJO from H2O and compiled it, but how do I use it?

Posted: 2017-03-05 02:11:23

Tags: java python pojo h2o

I'm using the following sample code, which I found in this post, to download a POJO:

import h2o
h2o.init()
iris_df = h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/iris/iris.csv")
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
predictors = iris_df.columns[0:4]
response_col = "C5"
train, valid, test = iris_df.split_frame([.7, .15], seed=1234)
glm_model = H2OGeneralizedLinearEstimator(family="multinomial")
glm_model.train(predictors, response_col, training_frame=train, validation_frame=valid)
h2o.download_pojo(glm_model, path='/Users/your_user_name/Desktop/', get_jar=True)

When I open the downloaded Java file, it contains instructions for compiling it. The following compiles successfully:

javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m GLM_model_python_1488677745392_2.java

Now I'm not sure how to use it. I tried the following:

java -cp h2o-genmodel.jar javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m GLM_model_python_1488677745392_2.java

Here is the code in the POJO:

/*
  Licensed under the Apache License, Version 2.0
    http://www.apache.org/licenses/LICENSE-2.0.html

  AUTOGENERATED BY H2O at 2017-03-05T01:51:46.237Z
  3.10.3.2

  Standalone prediction code with sample test data for GLMModel named GLM_model_python_1488677745392_2

  How to download, compile and execute:
      mkdir tmpdir
      cd tmpdir
      curl http:/10.0.0.4/10.0.0.4:54321/3/h2o-genmodel.jar > h2o-genmodel.jar
      curl http:/10.0.0.4/10.0.0.4:54321/3/Models.java/GLM_model_python_1488677745392_2 > GLM_model_python_1488677745392_2.java
      javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m GLM_model_python_1488677745392_2.java

     (Note:  Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)
*/
import java.util.Map;
import hex.genmodel.GenModel;
import hex.genmodel.annotations.ModelPojo;

@ModelPojo(name="GLM_model_python_1488677745392_2", algorithm="glm")
public class GLM_model_python_1488677745392_2 extends GenModel {
  public hex.ModelCategory getModelCategory() { return hex.ModelCategory.Multinomial; }

  public boolean isSupervised() { return true; }
  public int nfeatures() { return 4; }
  public int nclasses() { return 3; }

  // Names of columns used by model.
  public static final String[] NAMES = NamesHolder_GLM_model_python_1488677745392_2.VALUES;
  // Number of output classes included in training data response column.
  public static final int NCLASSES = 3;

  // Column domains. The last array contains domain of response column.
  public static final String[][] DOMAINS = new String[][] {
    /* C1 */ null,
    /* C2 */ null,
    /* C3 */ null,
    /* C4 */ null,
    /* C5 */ GLM_model_python_1488677745392_2_ColInfo_4.VALUES
  };
  // Prior class distribution
  public static final double[] PRIOR_CLASS_DISTRIB = {0.2818181818181818,0.33636363636363636,0.38181818181818183};
  // Class distribution used for model building
  public static final double[] MODEL_CLASS_DISTRIB = null;

  public GLM_model_python_1488677745392_2() { super(NAMES,DOMAINS); }
  public String getUUID() { return Long.toString(-5598526670666235824L); }

  // Pass in data in a double[], pre-aligned to the Model's requirements.
  // Jam predictions into the preds[] array; preds[0] is reserved for the
  // main prediction (class for classifiers or value for regression),
  // and remaining columns hold a probability distribution for classifiers.
  public final double[] score0( double[] data, double[] preds ) {
    final double [] b = BETA.VALUES;
    for(int i = 0; i < 0; ++i) if(Double.isNaN(data[i])) data[i] = CAT_MODES.VALUES[i];
    for(int i = 0; i < 4; ++i) if(Double.isNaN(data[i + 0])) data[i+0] = NUM_MEANS.VALUES[i];
    preds[0] = 0;
    for(int c = 0; c < 3; ++c){
      preds[c+1] = 0;
      for(int i = 0; i < 4; ++i)
        preds[c+1] += b[0+i + c*5]*data[i];
      preds[c+1] += b[4 + c*5]; // reduce intercept
    }
    double max_row = 0;
    for(int c = 1; c < preds.length; ++c) if(preds[c] > max_row) max_row = preds[c];
    double sum_exp = 0;
    for(int c = 1; c < preds.length; ++c) { sum_exp += (preds[c] = Math.exp(preds[c]-max_row));}
    sum_exp = 1/sum_exp;
    double max_p = 0;
    for(int c = 1; c < preds.length; ++c) if((preds[c] *= sum_exp) > max_p){ max_p = preds[c]; preds[0] = c-1;};
    return preds;
  }
    public static class BETA implements java.io.Serializable {
      public static final double[] VALUES = new double[15];
      static {
        BETA_0.fill(VALUES);
      }
      static final class BETA_0 implements java.io.Serializable {
        static final void fill(double[] sa) {
          sa[0] = -1.4700470387418272;
          sa[1] = 4.26067731522767;
          sa[2] = -2.285756276489862;
          sa[3] = -4.312931422791621;
          sa[4] = 5.231215014401568;
          sa[5] = 1.7769023115830205;
          sa[6] = -0.2534145823550425;
          sa[7] = -0.9887536067536575;
          sa[8] = -1.2706135235877678;
          sa[9] = -4.319817154759757;
          sa[10] = 0.0;
          sa[11] = -3.024835247270209;
          sa[12] = 3.8622405283810464;
          sa[13] = 7.018262604176258;
          sa[14] = -22.702291637028203;
        }
      }
}
// Imputed numeric values
    static class NUM_MEANS implements java.io.Serializable {
      public static final double[] VALUES = new double[4];
      static {
        NUM_MEANS_0.fill(VALUES);
      }
      static final class NUM_MEANS_0 implements java.io.Serializable {
        static final void fill(double[] sa) {
          sa[0] = 5.90272727272727;
          sa[1] = 3.024545454545454;
          sa[2] = 3.9490909090909097;
          sa[3] = 1.2872727272727267;
        }
      }
}
// Imputed categorical values.
    static class CAT_MODES implements java.io.Serializable {
      public static final int[] VALUES = new int[0];
      static {
      }
}
    // Categorical Offsets
    public static final int[] CATOFFS = {0};
}
// The class representing training column names
class NamesHolder_GLM_model_python_1488677745392_2 implements java.io.Serializable {
  public static final String[] VALUES = new String[4];
  static {
    NamesHolder_GLM_model_python_1488677745392_2_0.fill(VALUES);
  }
  static final class NamesHolder_GLM_model_python_1488677745392_2_0 implements java.io.Serializable {
    static final void fill(String[] sa) {
      sa[0] = "C1";
      sa[1] = "C2";
      sa[2] = "C3";
      sa[3] = "C4";
    }
  }
}
// The class representing column C5
class GLM_model_python_1488677745392_2_ColInfo_4 implements java.io.Serializable {
  public static final String[] VALUES = new String[3];
  static {
    GLM_model_python_1488677745392_2_ColInfo_4_0.fill(VALUES);
  }
  static final class GLM_model_python_1488677745392_2_ColInfo_4_0 implements java.io.Serializable {
    static final void fill(String[] sa) {
      sa[0] = "Iris-setosa";
      sa[1] = "Iris-versicolor";
      sa[2] = "Iris-virginica";
    }
  }
}

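Now I think I need to call score0. I've worked out how to write my own Main.java with a main() entry point so that I can instantiate the object and call score0, but I don't understand how it is supposed to work. I expected to pass in 4 doubles and get a category back, but instead the function takes two double[] arguments, and I can't figure out what exactly goes where or how to read the result. Here is my main: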

public class Main {
  public static void main(String[] args) {
      double[] input = {4.6, 3.1, 1.5, 0.2};
      double[] output = new double[4];
      GLM_model_python_1488677745392_2 m = new GLM_model_python_1488677745392_2();
      double[] t = m.score0(input,output);
      for(int i = 0; i < t.length; i++) System.out.println(t[i]);
  }
}

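I do get a bunch of numbers back, but I don't know what they mean. I suspect I'm using the second argument completely wrong, but I don't know what to do instead. Here is the output: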

0.0
0.9976588811416329
0.0023411188583572825
9.662837354438092E-15

3 Answers:

Answer 0 (score: 3)

A few points:

For anyone who really wants to call score0 directly, the best documentation is the EasyPredictModelWrapper source code.
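As a rough illustration of what calling score0 directly involves, the sketch below does by hand what a wrapper has to do around it: align named input values into a double[] in the model's column order (the POJO's NAMES array, here C1..C4), allocate preds with nclasses() + 1 slots, call score0, and map preds[0] back to a class name from the response domain. The Scorer interface, the DirectScoring class and the stub model are hypothetical stand-ins so the code compiles without h2o-genmodel.jar; this is not EasyPredictModelWrapper's actual implementation, which also handles categorical encoding, unknown levels and the other model categories.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the two GenModel methods used here, so this compiles
// without h2o-genmodel.jar. A real POJO extends hex.genmodel.GenModel instead.
interface Scorer {
    int nclasses();
    double[] score0(double[] data, double[] preds);
}

class DirectScoring {
    // Align named inputs to the model's column order. Missing or unparseable
    // values become NaN (the generated score0 imputes NaN numerics with NUM_MEANS).
    static double[] toRow(String[] names, Map<String, String> values) {
        double[] data = new double[names.length];
        for (int i = 0; i < names.length; i++) {
            String v = values.get(names[i]);
            try {
                data[i] = (v == null) ? Double.NaN : Double.parseDouble(v);
            } catch (NumberFormatException e) {
                data[i] = Double.NaN;
            }
        }
        return data;
    }

    // preds[0] receives the predicted class index, preds[1..nclasses] the probabilities.
    static String predictLabel(Scorer model, String[] names, String[] domain,
                               Map<String, String> values) {
        double[] preds = new double[model.nclasses() + 1];
        model.score0(toRow(names, values), preds);
        return domain[(int) preds[0]];
    }

    public static void main(String[] args) {
        String[] names = {"C1", "C2", "C3", "C4"};
        String[] domain = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};
        // Hypothetical stub model: ignores its input and always predicts class 1.
        Scorer stub = new Scorer() {
            public int nclasses() { return 3; }
            public double[] score0(double[] data, double[] preds) {
                preds[0] = 1; preds[1] = 0; preds[2] = 1; preds[3] = 0;
                return preds;
            }
        };
        Map<String, String> row = new HashMap<>();
        row.put("C1", "4.6");
        row.put("C2", "3.1");
        row.put("C3", "1.5");
        row.put("C4", "0.2");
        System.out.println(predictLabel(stub, names, domain, row)); // prints Iris-versicolor (the stub's fixed answer)
    }
}
```

With the real POJO you would skip the interface and call score0 on the generated class itself, passing the four feature values in NAMES order and a preds array of length nclasses() + 1 (4 for this model).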


Here is the POJO usage snippet from the documentation (H2O version 3.10.4.1) showing how to make new predictions with the Easy API:

import java.io.*;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.prediction.*;

public class main {
  private static String modelClassName = "gbm_pojo_test";

  public static void main(String[] args) throws Exception {
    hex.genmodel.GenModel rawModel;
    rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
    EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);
    //
    // By default, unknown categorical levels throw PredictUnknownCategoricalLevelException.
    // Optionally configure the wrapper to treat unknown categorical levels as N/A instead
    // and strings that cannot be converted to numbers also to N/As:
    //
    //     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
    //         new EasyPredictModelWrapper.Config()
    //             .setModel(rawModel)
    //             .setConvertUnknownCategoricalLevelsToNa(true)
    //             .setConvertInvalidNumbersToNa(true)
    //     );

    RowData row = new RowData();
    row.put("Year", "1987");
    row.put("Month", "10");
    row.put("DayofMonth", "14");
    row.put("DayOfWeek", "3");
    row.put("CRSDepTime", "730");
    row.put("UniqueCarrier", "PS");
    row.put("Origin", "SAN");
    row.put("Dest", "SFO");

    BinomialModelPrediction p = model.predictBinomial(row);
    System.out.println("Label (aka prediction) is flight departure delayed: " + p.label);
    System.out.print("Class probabilities: ");
    for (int i = 0; i < p.classProbabilities.length; i++) {
      if (i > 0) {
        System.out.print(",");
      }
      System.out.print(p.classProbabilities[i]);
    }
    System.out.println("");
  }
}

After filling in the new values in row (which is just a map), you call predictBinomial() to make the prediction.

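Almost exactly the same code works for a MOJO, except that you instantiate the model from a data file instead of from a class. So instead of the POJO code: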

    hex.genmodel.GenModel rawModel;
    rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
    EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);

you have the MOJO code:

    EasyPredictModelWrapper model = new EasyPredictModelWrapper(hex.genmodel.MojoModel.load("GBM_model_R_1475248925871_74.zip"));

Answer 1 (score: 2)

In case anyone finds this thread: I found a simpler way to use the downloaded POJO. H2O Steam handles it nicely:

~/steam-1.1.6-linux-amd64 > java -jar var/master/assets/jetty-runner.jar --port 8888 var/master/assets/ROOT.war &
curl -X POST --form pojo=@/home/tome/pojo/DL_defaults.java --form jar=@/home/tome/pojo/h2o-genmodel.jar localhost:8888/makewar > example.war
java -jar /home/tome/steam-1.1.6-linux-amd64/var/master/assets/jetty-runner.jar --port 7077 example.war

Then you can query it:

01:34:57 PS C:\dropbox\scripts> Invoke-RestMethod "http://notexist.eastus.cloudapp.azure.com:7077/predict?C1=4.6&C2=3.1&C3=1.5&C4=0.2"


labelIndex label       classProbabilities
---------- -----       ------------------
         0 Iris-setosa {0.9976588811416329, 0.0023411188583572825, 9.66283735443809E-15}

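The service's main HTTP page provides a nice interface for building your queries. If you don't like the approach above, the full version of Steam can connect directly to H2O and download, convert and deploy the model for you in just a few clicks.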
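If you would rather call the deployed service from Java than from PowerShell, the request is an HTTP GET on /predict with one query parameter per input column, as in the Invoke-RestMethod call above. Below is a minimal sketch that only builds that URL, URL-encoding each column name and value. The class name SteamQuery is hypothetical, and the base address http://localhost:7077 is an assumption matching the jetty-runner port used above. The response is presumably JSON, given that Invoke-RestMethod turned it into an object with labelIndex, label and classProbabilities fields, so any HTTP client can fetch it.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.LinkedHashMap;
import java.util.Map;

class SteamQuery {
    // URL-encode one query component (spaces become '+', '&' becomes %26, etc.).
    static String enc(String s) {
        try {
            return URLEncoder.encode(s, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is always supported", e);
        }
    }

    // Builds base + "/predict?C1=...&C2=..." from column-name -> value pairs,
    // in the map's iteration order.
    static String predictUrl(String base, Map<String, String> row) {
        StringBuilder sb = new StringBuilder(base).append("/predict");
        char sep = '?';
        for (Map.Entry<String, String> e : row.entrySet()) {
            sb.append(sep).append(enc(e.getKey())).append('=').append(enc(e.getValue()));
            sep = '&';
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // LinkedHashMap keeps the columns in insertion order.
        Map<String, String> row = new LinkedHashMap<>();
        row.put("C1", "4.6");
        row.put("C2", "3.1");
        row.put("C3", "1.5");
        row.put("C4", "0.2");
        // Assumed address: the jetty-runner instance started on port 7077.
        System.out.println(predictUrl("http://localhost:7077", row));
        // prints http://localhost:7077/predict?C1=4.6&C2=3.1&C3=1.5&C4=0.2
    }
}
```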

Answer 2 (score: 0)

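Never mind, it looks like I was right after all. The output refers to my 3 categories:

t[0] is the index of the predicted class (0, i.e. setosa)
t[1] is setosa
t[2] is versicolor
t[3] is virginica

0.9976588811416329 is the model's predicted probability for that category, so this data point is 99.8% setosa.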
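To see where those numbers come from, the sketch below re-implements the arithmetic of the generated score0 without depending on h2o-genmodel.jar. The coefficients are copied from the BETA class in the POJO: for class c, entries c*5 through c*5+3 multiply C1..C4 and entry c*5+4 is the intercept. It computes one linear score per class, applies a softmax, and writes the result in the same layout as score0 (preds[0] = predicted class index, preds[1..3] = probabilities). The class name IrisGlmByHand is made up for illustration; unlike the generated code it skips the NaN imputation step and starts the running maximum at negative infinity instead of 0, which does not change the resulting probabilities.

```java
// Standalone re-computation of the multinomial GLM score0 from the POJO above.
class IrisGlmByHand {
    // Copied from BETA.VALUES: for class c, [c*5 .. c*5+3] = C1..C4 coefficients, [c*5+4] = intercept.
    static final double[] BETA = {
        -1.4700470387418272, 4.26067731522767, -2.285756276489862, -4.312931422791621, 5.231215014401568,
        1.7769023115830205, -0.2534145823550425, -0.9887536067536575, -1.2706135235877678, -4.319817154759757,
        0.0, -3.024835247270209, 3.8622405283810464, 7.018262604176258, -22.702291637028203
    };
    static final String[] DOMAIN = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};

    // Returns preds in score0's layout: preds[0] = class index, preds[1..3] = probabilities.
    static double[] score(double[] x) {
        double[] preds = new double[DOMAIN.length + 1];
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < DOMAIN.length; c++) {
            double eta = BETA[c * 5 + 4];                      // intercept
            for (int i = 0; i < 4; i++) eta += BETA[c * 5 + i] * x[i];
            preds[c + 1] = eta;
            max = Math.max(max, eta);
        }
        double sum = 0;
        for (int c = 1; c < preds.length; c++) {
            preds[c] = Math.exp(preds[c] - max);               // shift by max for numerical stability
            sum += preds[c];
        }
        double best = -1;
        for (int c = 1; c < preds.length; c++) {
            preds[c] /= sum;                                   // softmax normalization
            if (preds[c] > best) { best = preds[c]; preds[0] = c - 1; }
        }
        return preds;
    }

    public static void main(String[] args) {
        double[] p = score(new double[] {4.6, 3.1, 1.5, 0.2});
        System.out.println("label = " + DOMAIN[(int) p[0]]);
        for (int c = 0; c < DOMAIN.length; c++) System.out.println(DOMAIN[c] + ": " + p[c + 1]);
    }
}
```

For the input from the question this should reproduce the question's output up to floating-point rounding: class index 0 (Iris-setosa) with a probability of about 0.99766.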