如何获得ML.NET中的偏见和权重?

时间:2019-11-17 19:07:40

标签: c# neural-network linear-regression ml.net

我在ML.NET中获得了线性回归模型,并且预测工作正常:

  MLContext mlContext = new MLContext(seed: 0);
            List<TwoInputRegressionModel> inputs = new List<TwoInputRegressionModel>();
            foreach (var JahrMitCO in ListWithCO)
            {
                float tempyear = JahrMitCO.Year;
                foreach (var JahrMitPopulation in Population)
                {
                    if (JahrMitPopulation.Year == tempyear)
                    {
                        inputs.Add(new TwoInputRegressionModel() { Year = tempyear, Population = JahrMitPopulation.Value, Co2 = JahrMitCO.Value });
                    }
                }
            }
            var model = Train(mlContext, inputs);
            TestSinglePrediction(mlContext, model); //works

但是我想知道如何获得训练模型的参数(权重+偏差)?我确实知道ITransformer类(这里称为model)确实包含Model属性,但是尝试像documentation上所述将其转换为LinearRegressionModelParameters类无效:

 LinearRegressionModelParameters originalModelParameters = ((ISingleFeaturePredictionTransformer<object>)model).Model as LinearRegressionModelParameters; //Exception:System.InvalidCastException
  

类型的对象   Microsoft.ML.Data.TransformerChain 1[Microsoft.ML.Data.RegressionPredictionTransformer 1 [Microsoft.ML.Trainers.FastTree.FastTreeRegressionModelParameters]]   不能转换为   Microsoft.ML.ISingleFeaturePredictionTransformer`1 [System.Object]

如何访问模型参数?

1 个答案:

答案 0 :(得分:0)

在这种情况下,问题在于您的model对象不是ISingleFeaturePredictionTransformer,而是一个TransformerChain对象(即一连串的变压器),其中{{ 1}}是“预测转换器”。

要解决此问题,请先将LastTransformer投射到model,然后可以得到TransformerChain<RegressionPredictionTransformer<FastTreeRegressionModelParameters>>,它会返回LastTransformer,从那里可以得到{ {1}}属性。

如果您在编译时不知道TransformerChain将包含哪种具体类型的变压器,则可以将RegressionPredictionTransformer<FastTreeRegressionModelParameters>转换为Model并获得链中的model变压器。您可以将其强制转换为IEnumerable<ITransformer>以获得.Last()属性。

ISingleFeaturePredictionTransformer<object>

从那里,您可以将Model转换为碰巧是任何特定的 ITransformer model = ...; IEnumerable<ITransformer> chain = model as IEnumerable<ITransformer>; ISingleFeaturePredictionTransformer<object> predictionTransformer = chain.Last() as ISingleFeaturePredictionTransformer<object>; object modelParameters = predictionTransformer.Model; 类。

注意:从您的异常消息中,您并不是在训练线性回归模型,而是在训练快速树模型。基于树的模型将无法强制转换为modelParameters,因此您将不会看到偏差和权重,而是会看到树信息。