How to get H2O ModelMetricsBinomial via the REST API with Retrofit and h2o-bindings

Posted: 2018-01-04 12:08:11

Tags: h2o

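I have several binomial DRF models, and I want to get the ModelMetricsBinomialV3 object so I can extract the thresholds_and_metric_scores variable. I have already implemented a solution without Retrofit and the bindings, but I would like to use h2o-bindings to send and receive POJOs, because my current solution is not very elegant and is very error-prone. Has anyone done this and could share their code? In particular, I am interested in extracting accuracy, F1 score, recall, etc. for a given threshold.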
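For reference, accuracy, precision, recall and F1 at a single threshold are all functions of the four confusion-matrix counts at that threshold (true/false positives and negatives). The following is a minimal, JDK-only sketch of a small value class that could replace a semicolon-joined result string; `ConfusionMetrics` is a hypothetical name, not part of H2O or h2o-bindings, and returning `NaN` for an undefined ratio is a choice made here, not H2O's documented behavior.

```java
import java.util.Locale;

/**
 * Hypothetical value class: derives threshold-dependent metrics from the
 * confusion-matrix counts (tp, tn, fp, fn) observed at one threshold.
 */
final class ConfusionMetrics {
    final long tp, tn, fp, fn;

    ConfusionMetrics(long tp, long tn, long fp, long fn) {
        this.tp = tp;
        this.tn = tn;
        this.fp = fp;
        this.fn = fn;
    }

    /** (tp + tn) / all rows; NaN if there are no rows. */
    double accuracy() {
        return ratio(tp + tn, tp + tn + fp + fn);
    }

    /** tp / (tp + fp); NaN if nothing was predicted positive. */
    double precision() {
        return ratio(tp, tp + fp);
    }

    /** tp / (tp + fn); NaN if there are no actual positives. */
    double recall() {
        return ratio(tp, tp + fn);
    }

    /** Harmonic mean of precision and recall, written as 2tp / (2tp + fp + fn). */
    double f1() {
        return ratio(2 * tp, 2 * tp + fp + fn);
    }

    // Undefined ratios (zero denominator) are reported as NaN
    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? Double.NaN : (double) numerator / denominator;
    }

    @Override
    public String toString() {
        // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM locale
        return String.format(Locale.ROOT,
                "accuracy=%.4f f1=%.4f recall=%.4f precision=%.4f tp=%d tn=%d fp=%d fn=%d",
                accuracy(), f1(), recall(), precision(), tp, tn, fp, fn);
    }
}
```

For an unweighted frame, recomputing the metrics from the tps/tns/fps/fns read for a row and comparing them with the accuracy, f1, precision and recall values from the same row should agree up to rounding. That makes it a cheap sanity check that the column positions are right.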

My current approach is:

h2oApi.predict(ModelMetricsListSchemaV3)

This works, but the result does not contain thresholds_and_metric_scores.

Calling POST /3/Predictions/models/{model}/frames/{frame} in Postman works fine and returns thresholds_and_metric_scores in the JSON string. How can that be, when h2oApi internally calls the same POST /3/Predictions/models/{model}/frames/{frame}?!

This is my previous implementation:

import java.io.IOException;
import java.io.UncheckedIOException;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;

// LOG and buildHttpPath(...) are defined elsewhere in the same class.
public String getModelMetrics(String modelId, String frameId, double threshold) {
  String url = buildHttpPath("/3/Predictions/models/" + modelId + "/frames/" + frameId);
  String json;

  // try-with-resources closes the client (and its connections) when done
  try (CloseableHttpClient client = HttpClientBuilder.create().build()) {
    HttpResponse response = client.execute(new HttpPost(url));
    json = EntityUtils.toString(response.getEntity());
    int status = response.getStatusLine().getStatusCode();
    if (status != 200) {
      throw new IllegalStateException("H2O returned HTTP " + status + ": " + json);
    }
  } catch (IOException exception) {
    LOG.error(exception.toString());
    // Without a response there is nothing to parse; fail instead of handing "" to Gson
    throw new UncheckedIOException(exception);
  }

  // model_metrics[0].thresholds_and_metric_scores.data is column-major: data[column][row]
  JsonObject root = new Gson().fromJson(json, JsonObject.class);
  JsonObject firstMetrics = root.getAsJsonArray("model_metrics").get(0).getAsJsonObject();
  JsonArray data = firstMetrics.getAsJsonObject("thresholds_and_metric_scores").getAsJsonArray("data");
  JsonArray thresholds = data.get(0).getAsJsonArray();

  // Find the row whose threshold is closest to the requested one
  int metricIndex = -1;
  double min = Double.MAX_VALUE;
  for (int i = 0; i < thresholds.size(); i++) {
    double diff = Math.abs(thresholds.get(i).getAsDouble() - threshold);
    if (diff < min) {
      min = diff;
      metricIndex = i;
    }
  }
  if (metricIndex < 0) {
    throw new IllegalStateException("thresholds_and_metric_scores has no rows");
  }

  LOG.info("Received threshold is: " + threshold);
  LOG.info("Nearest threshold is: " + thresholds.get(metricIndex).getAsDouble());

  // Column positions are hard-coded and depend on the column order of the H2O version in use
  double accuracy = data.get(4).getAsJsonArray().get(metricIndex).getAsDouble();
  double f1 = data.get(1).getAsJsonArray().get(metricIndex).getAsDouble();
  double recall = data.get(6).getAsJsonArray().get(metricIndex).getAsDouble();
  double precision = data.get(5).getAsJsonArray().get(metricIndex).getAsDouble();
  long tp = data.get(14).getAsJsonArray().get(metricIndex).getAsLong();
  long tn = data.get(11).getAsJsonArray().get(metricIndex).getAsLong();
  long fp = data.get(13).getAsJsonArray().get(metricIndex).getAsLong();
  long fn = data.get(12).getAsJsonArray().get(metricIndex).getAsLong();

  return accuracy + ";" + f1 + ";" + recall + ";" + precision + ";" + tp + ";" + tn + ";" + fp + ";" + fn;
}
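The fragile part of the implementation above is the hard-coded column positions (4 for accuracy, 14 for true positives, and so on). The following is a minimal, JDK-only sketch of a name-based lookup over a column-major table (`data[column][row]`, as the code above already assumes for `data`). It also assumes that the column names have been copied into a plain array beforehand, for example from the `name` fields of the table's `columns` entries. `ThresholdTable` and its methods are hypothetical names, not part of H2O or h2o-bindings.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helper: looks up metrics by column name in a column-major
 * table such as thresholds_and_metric_scores (data[column][row]).
 */
final class ThresholdTable {
    private final List<String> columnNames;
    private final double[][] data; // data[column][row]

    ThresholdTable(String[] columnNames, double[][] data) {
        if (columnNames.length != data.length) {
            throw new IllegalArgumentException("expected one name per column");
        }
        this.columnNames = Arrays.asList(columnNames);
        this.data = data;
    }

    /** Returns the named column, failing loudly instead of silently reading the wrong one. */
    double[] column(String name) {
        int idx = columnNames.indexOf(name);
        if (idx < 0) {
            throw new IllegalArgumentException("no column named " + name);
        }
        return data[idx];
    }

    /** Index of the row whose "threshold" value is closest to the requested one (first wins on ties). */
    int nearestRow(double threshold) {
        double[] thresholds = column("threshold");
        if (thresholds.length == 0) {
            throw new IllegalStateException("table has no rows");
        }
        int best = 0;
        for (int i = 1; i < thresholds.length; i++) {
            if (Math.abs(thresholds[i] - threshold) < Math.abs(thresholds[best] - threshold)) {
                best = i;
            }
        }
        return best;
    }

    /** Value of the named metric at the row nearest to the requested threshold. */
    double valueAt(String metric, double threshold) {
        return column(metric)[nearestRow(threshold)];
    }
}
```

With this, `table.valueAt("accuracy", threshold)` takes the place of `data.get(4)`. The names `threshold`, `f1`, `accuracy`, `precision`, `recall`, `tps`, `tns`, `fps` and `fns` are assumed here; check them against the `columns` your H2O version actually returns.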

Thanks in advance!

0 answers