Trained model is empty after fitting

Time: 2020-06-29 09:34:29

Tags: ml.net

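I am trying to implement a simple linear regression with ML.NET. I followed Microsoft's tutorial, except that I load the data from an IEnumerable rather than from a file. However, every prediction comes out as zero. Looking into it, I found that the model's parameters are empty (see the screenshot below).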

This is the simplest program with which I could reproduce the problem.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

class Data
{
    public float Revenue { get; set; }
    public float Customers { get; set; }
    public float IsWeekend { get; set; }
}

class DataPrediction
{
    // Bind the trainer's "Score" output column to this property.
    [ColumnName("Score")]
    public float RevenueEstimate { get; set; }
}

class Program
{
    static void Main(string[] args)
    {
        // Seven training rows: five weekdays and two weekend days.
        var trainData = new List<Data>
        {
            new Data { IsWeekend = 0, Customers = 20, Revenue = 138 },
            new Data { IsWeekend = 0, Customers = 18, Revenue = 106 },
            new Data { IsWeekend = 0, Customers = 26, Revenue = 142 },
            new Data { IsWeekend = 0, Customers = 21, Revenue = 131 },
            new Data { IsWeekend = 0, Customers = 33, Revenue = 146 },
            new Data { IsWeekend = 1, Customers = 92, Revenue = 287 },
            new Data { IsWeekend = 1, Customers = 113, Revenue = 312 },
        };

        // Three held-out rows used only for evaluation.
        var testData = new List<Data>
        {
            new Data { IsWeekend = 0, Customers = 27, Revenue = 136 },
            new Data { IsWeekend = 0, Customers = 22, Revenue = 109 },
            new Data { IsWeekend = 1, Customers = 87, Revenue = 256 },
        };

        MLContext mlContext = new MLContext();
        IDataView trainDataView = mlContext.Data.LoadFromEnumerable(trainData);
        IDataView testDataView = mlContext.Data.LoadFromEnumerable(testData);

        // Revenue becomes the label; Customers and IsWeekend become the feature vector.
        // FastTree is used with its default options.
        var pipeline = mlContext.Transforms.CopyColumns("Label", "Revenue")
            .Append(mlContext.Transforms.Concatenate("Features", "Customers", "IsWeekend"))
            .Append(mlContext.Regression.Trainers.FastTree());
        var model = pipeline.Fit(trainDataView);

        // Score the test set and compute regression metrics.
        var predictions = model.Transform(testDataView);
        var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

        // Predict a single example; Revenue is left unset because it is the target.
        var predictionFunction = mlContext.Model.CreatePredictionEngine<Data, DataPrediction>(model);
        var prediction = predictionFunction.Predict(new Data { IsWeekend = 1, Customers = 108 });
    }
}

If I set a breakpoint on the closing brace, this is what I get:

[Screenshot not preserved: debugger view in which the model's parameters are empty]

Naturally, prediction.RevenueEstimate is 0, and the metrics show a mean absolute error and an RMS error that are close to the mean of the label.
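To make that observation concrete, here is a minimal sketch of what the two metrics reduce to if every test prediction is exactly 0, which is an assumption based on the zero prediction I see, not something read out of the metrics object. The class and method names (ZeroBaseline, MeanAbsoluteError, RootMeanSquaredError) are hypothetical helpers, not part of ML.NET.

```csharp
using System;
using System.Linq;

static class ZeroBaseline
{
    // Mean absolute error when every prediction is 0: this is just mean(|label|).
    public static double MeanAbsoluteError(double[] labels) =>
        labels.Average(y => Math.Abs(y - 0.0));

    // Root mean squared error when every prediction is 0: sqrt(mean(label^2)).
    public static double RootMeanSquaredError(double[] labels) =>
        Math.Sqrt(labels.Average(y => (y - 0.0) * (y - 0.0)));

    static void Main()
    {
        // Revenue values of the three test rows from the program above.
        double[] testRevenue = { 136, 109, 256 };

        Console.WriteLine($"Label mean: {testRevenue.Average()}");
        Console.WriteLine($"MAE with all-zero predictions: {MeanAbsoluteError(testRevenue)}");
        Console.WriteLine($"RMSE with all-zero predictions: {RootMeanSquaredError(testRevenue)}");
    }
}
```

For the three test rows the label mean is 167, so an all-zero model gives a mean absolute error of exactly 167 and an RMS error of about 178.8. Metrics in that range are what a model that learned nothing would produce.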

This looks like a very basic program to me, and I am using the FastTree regression trainer, as Microsoft suggests in its tutorial. What am I doing wrong?

PS: Decorating the data class with ColumnName attributes makes no difference.
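A possible lead, which is an assumption and not confirmed anywhere in this thread: as far as I know, FastTree's minimumExampleCountPerLeaf parameter defaults to 10, while the training set here has only 7 rows. If that is right, the trainer may be unable to make any split, which would leave the trees empty and the score at 0. The sketch below passes smaller values explicitly. The specific numbers (4 leaves, 20 trees, 1 example per leaf) and the fixed seed are illustrative guesses, not recommendations, and the code assumes the Microsoft.ML and Microsoft.ML.FastTree NuGet packages are installed.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

class Data
{
    public float Revenue { get; set; }
    public float Customers { get; set; }
    public float IsWeekend { get; set; }
}

class DataPrediction
{
    [ColumnName("Score")]
    public float RevenueEstimate { get; set; }
}

class Program
{
    static void Main()
    {
        // Same seven training rows as in the repro.
        var trainData = new List<Data>
        {
            new Data { IsWeekend = 0, Customers = 20, Revenue = 138 },
            new Data { IsWeekend = 0, Customers = 18, Revenue = 106 },
            new Data { IsWeekend = 0, Customers = 26, Revenue = 142 },
            new Data { IsWeekend = 0, Customers = 21, Revenue = 131 },
            new Data { IsWeekend = 0, Customers = 33, Revenue = 146 },
            new Data { IsWeekend = 1, Customers = 92, Revenue = 287 },
            new Data { IsWeekend = 1, Customers = 113, Revenue = 312 },
        };

        var mlContext = new MLContext(seed: 0); // fixed seed is an illustrative choice
        IDataView trainDataView = mlContext.Data.LoadFromEnumerable(trainData);

        var pipeline = mlContext.Transforms.CopyColumns("Label", "Revenue")
            .Append(mlContext.Transforms.Concatenate("Features", "Customers", "IsWeekend"))
            .Append(mlContext.Regression.Trainers.FastTree(
                numberOfLeaves: 4,               // illustrative value for a tiny dataset
                numberOfTrees: 20,               // illustrative value
                minimumExampleCountPerLeaf: 1)); // assumption: below the row count so splits are possible

        var model = pipeline.Fit(trainDataView);

        var engine = mlContext.Model.CreatePredictionEngine<Data, DataPrediction>(model);
        var prediction = engine.Predict(new Data { IsWeekend = 1, Customers = 108 });

        // If the assumption holds, this should no longer be 0.
        Console.WriteLine(prediction.RevenueEstimate);
    }
}
```

If the prediction is still 0 with these settings, the cause lies elsewhere and this lead can be discarded. With only seven rows, any tree ensemble will fit the data very loosely, so a linear trainer may be a better match for the stated goal of a simple linear regression.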

0 answers