How to pass dynamic classes to a generic function?

Time: 2019-01-09 21:56:26

Tags: c# class generics dynamic ml.net

I want to build a machine learning API for use with a web application. The field names will be passed to the API along with their data types.

Currently I am making a class at runtime with the code provided in this answer: https://stackoverflow.com/a/3862241

The problem arises when I need to call the ML.NET PredictionFunction: I can't pass in the type arguments for the generic function, since the types are created at runtime. I've tried calling it through reflection, but reflection does not seem to find the method.

NOTE: The ML.NET docs are currently being updated for 0.9.0, so they are unavailable.

What I've tried is this (minimal):

Type[] typeArgs = { generatedType, typeof(ClusterPrediction) };
object[] parametersArray = { mlContext }; // value

MethodInfo method = typeof(TransformerChain).GetMethod("MakePredictionFunction");
if (method == null) { // Using PredictionFunctionExtensions helps here
  Console.WriteLine("Method not found!");
}
MethodInfo generic = method.MakeGenericMethod(typeArgs);
var temp = generic.Invoke(model, parametersArray);

The full (revised and trimmed) source (for more context): Program.cs

namespace Generic {
  class Program {
    public class GenericData {
      public float SepalLength;
      public float SepalWidth;
      public float PetalLength;
      public float PetalWidth;
    }
    public class ClusterPrediction {
      public uint PredictedLabel;
      public float[] Score;
    }

    static void Main(string[] args) {
      List<Field> fields = new List<Field>() {
                new Field(){ name="SepalLength", type=typeof(float)},
                new Field(){ name="SepalWidth", type=typeof(float)},
                new Field(){ name="PetalLength", type=typeof(float)},
                new Field(){ name="PetalWidth", type=typeof(float)},
            };
      var generatedType = GenTypeBuilder.CompileResultType(fields);

      var mlContext = new MLContext(seed: 0);
      TextLoader textLoader = mlContext.Data.TextReader(new TextLoader.Arguments() {
        Separator = ",",
        Column = new[]
        {
          new TextLoader.Column("SepalLength", DataKind.R4, 0),
          new TextLoader.Column("SepalWidth", DataKind.R4, 1),
          new TextLoader.Column("PetalLength", DataKind.R4, 2),
          new TextLoader.Column("PetalWidth", DataKind.R4, 3)
        }
      });
      IDataView dataView = textLoader.Read(Path.Combine(Environment.CurrentDirectory, "Data", "flowers.txt"));

      var pipeline = mlContext.Transforms
        .Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
        .Append(mlContext.Clustering.Trainers.KMeans("Features", clustersCount: 3));
      var model = pipeline.Fit(dataView);

      Type[] typeArgs = { generatedType, typeof(ClusterPrediction) };
      object[] parametersArray = { mlContext }; // value

      MethodInfo method = typeof(TransformerChain).GetMethod("MakePredictionFunction");
      if (method == null) { // Using PredictionFunctionExtensions helps here
        Console.WriteLine("Method not found!");
      }
      MethodInfo generic = method.MakeGenericMethod(typeArgs);
      var temp = generic.Invoke(model, parametersArray);

      var prediction = temp.Predict(new GenericData {SepalLength = 5.6f, SepalWidth = 2.5f,
                                                     PetalLength = 3.9f, PetalWidth = 1.1f});
    }
  }
}

2 answers:

Answer 0 (score: 0)

Try reading your test data into an IDataView, then pass that IDataView to model.Transform().

That should insert the Score and the PredictedLabel as separate columns in your test data.

Answer 1 (score: 0)

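It seems that, when trying to reflect on MakePredictionFunction, you have confused the type TransformerChain<TLastTransformer> (an instantiable generic type) with the static class TransformerChain.

However, even reflecting on TransformerChain<TLastTransformer> would not succeed, because MakePredictionFunction is not a method declared by that type. Instead, MakePredictionFunction is an extension method declared in the static class PredictionFunctionExtensions⁽¹⁾.

Therefore, to get the MethodInfo for MakePredictionFunction, try the following: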

MethodInfo method = typeof(PredictionFunctionExtensions).GetMethod("MakePredictionFunction");
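
Note that an extension method is an ordinary static method, so a MethodInfo obtained this way is probably not invoked the way the question's code does it. After MakeGenericMethod, Invoke takes null as the target, and the extended object goes first in the argument array. The sketch below illustrates that pattern with hypothetical stand-in types (Model, Context, Predictor, PredictorExtensions, MakePredictor, Input, Output). None of them are ML.NET APIs, and the actual signature of MakePredictionFunction is not verified here.

```csharp
using System;
using System.Reflection;

// Hypothetical stand-ins for illustration only; these are NOT ML.NET types.
public class Model { }
public class Context { }

public class Predictor<TSrc, TDst>
    where TSrc : class
    where TDst : class, new()
{
    // Returns an empty result object; a real predictor would score the example.
    public TDst Predict(TSrc example) => new TDst();
}

public static class PredictorExtensions
{
    // Generic extension method, similar in shape to what the answer describes.
    public static Predictor<TSrc, TDst> MakePredictor<TSrc, TDst>(this Model model, Context context)
        where TSrc : class
        where TDst : class, new()
        => new Predictor<TSrc, TDst>();
}

public class Input { public float Value; }
public class Output { public uint Label; }

public static class Demo
{
    public static object Run()
    {
        // Stands in for a type that is only known at runtime.
        Type inputType = typeof(Input);

        // Look the method up on the static class that declares it.
        MethodInfo open = typeof(PredictorExtensions).GetMethod("MakePredictor");

        // Close the generic method definition over the runtime types.
        MethodInfo closed = open.MakeGenericMethod(inputType, typeof(Output));

        // Static method: the target is null, and the "this" argument comes first.
        object predictor = closed.Invoke(null, new object[] { new Model(), new Context() });

        // Invoke returns object, so Predict is also called through reflection.
        object example = Activator.CreateInstance(inputType);
        MethodInfo predict = predictor.GetType().GetMethod("Predict");
        return predict.Invoke(predictor, new[] { example });
    }

    public static void Main()
    {
        Console.WriteLine(Run().GetType().Name); // prints "Output"
    }
}
```

If the declaring class has several overloads with the same name, GetMethod(string) throws AmbiguousMatchException. In that case the overload has to be selected explicitly, for example by filtering the result of GetMethods().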



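⁽¹⁾ I am not 100% sure which namespace PredictionFunctionExtensions resides in. Searching the ML.NET 0.9.0 API documentation suggests that it resides in the Microsoft.ML.Runtime.Data namespace. However, trying to open the actual documentation page for MakePredictionFunction currently just results in a 404 error, so this information might not be accurate (I am not an ML.NET user, so I cannot verify it) :-(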