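I'm trying to implement a simple linear regression with ML.NET. I followed Microsoft's tutorial, except that I load the data from an IEnumerable rather than from a file. However, I always get zero as the prediction. Looking closer, I found that the model's parameters are empty (see the screenshot below).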
Here is the simplest program with which I can reproduce the problem.
    using System.Collections.Generic;
    using System.Threading.Tasks;
    using Microsoft.ML;
    using Microsoft.ML.Data;

    // Input schema: one row per day.
    class Data
    {
        public float Revenue { get; set; }
        public float Customers { get; set; }
        public float IsWeekend { get; set; }
    }

    // Output schema: the trainer writes its prediction to the "Score" column.
    class DataPrediction
    {
        [ColumnName("Score")]
        public float RevenueEstimate { get; set; }
    }

    class Program
    {
        static async Task Main(string[] args)
        {
            // Seven training rows and three test rows, all held in memory.
            List<Data> trainData = new List<Data>
            {
                new Data { IsWeekend = 0, Customers = 20, Revenue = 138 },
                new Data { IsWeekend = 0, Customers = 18, Revenue = 106 },
                new Data { IsWeekend = 0, Customers = 26, Revenue = 142 },
                new Data { IsWeekend = 0, Customers = 21, Revenue = 131 },
                new Data { IsWeekend = 0, Customers = 33, Revenue = 146 },
                new Data { IsWeekend = 1, Customers = 92, Revenue = 287 },
                new Data { IsWeekend = 1, Customers = 113, Revenue = 312 },
            };
            List<Data> testData = new List<Data>
            {
                new Data { IsWeekend = 0, Customers = 27, Revenue = 136 },
                new Data { IsWeekend = 0, Customers = 22, Revenue = 109 },
                new Data { IsWeekend = 1, Customers = 87, Revenue = 256 },
            };

            MLContext mlContext = new MLContext();
            IDataView trainDataView = mlContext.Data.LoadFromEnumerable(trainData);
            IDataView testDataView = mlContext.Data.LoadFromEnumerable(testData);

            // Revenue is the label; Customers and IsWeekend are the features.
            var pipeline = mlContext.Transforms.CopyColumns("Label", "Revenue")
                .Append(mlContext.Transforms.Concatenate("Features", "Customers", "IsWeekend"))
                .Append(mlContext.Regression.Trainers.FastTree());

            var model = pipeline.Fit(trainDataView);

            // Evaluate on the held-out rows.
            var predictions = model.Transform(testDataView);
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

            // Predict a single new example.
            var predictionFunction = mlContext.Model.CreatePredictionEngine<Data, DataPrediction>(model);
            var prediction = predictionFunction.Predict(new Data { IsWeekend = 1, Customers = 108 });
        }
    }
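If I set a breakpoint at the closing brace of Main, this is what I get:
[Screenshot: debugger view of the trained model, whose parameters are empty]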
Of course, prediction.RevenueEstimate is 0, and the metrics report a mean absolute error and an RMS error that are close to the mean of the labels.
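As a sanity check on those metrics, here is the arithmetic under the assumption that the model outputs exactly 0 for all three test rows (labels 136, 109 and 256). In that case both errors depend only on the labels:

```latex
\begin{aligned}
\text{mean label} &= \frac{136 + 109 + 256}{3} = 167 \\
\mathrm{MAE} &= \frac{|136 - 0| + |109 - 0| + |256 - 0|}{3} = \frac{501}{3} = 167 \\
\mathrm{RMSE} &= \sqrt{\frac{136^2 + 109^2 + 256^2}{3}} = \sqrt{\frac{95913}{3}} = \sqrt{31971} \approx 178.8
\end{aligned}
```

So an MAE equal to the label mean and an RMSE slightly above it are what an all-zero predictor would produce on this test set, which fits the picture of a model that learned nothing.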
This seems like a very basic program to me, and I'm using the FastTree regression trainer just as Microsoft suggests in its tutorial. What am I doing wrong?
P.S. Decorating the properties of the data class with ColumnName attributes makes no difference.