目标:
对PaymentType进行预测。目标名称或预测列是字符串值。
问题:
我检索到一条错误消息“ ArgumentOutOfRangeException:训练标签列'标签'类型不适用于回归:文本。类型必须为R4或R8。参数名称:data”
我缺少源代码的一部分吗?
谢谢!
数据:
https://github.com/dotnet/machinelearning/blob/master/test/data/taxi-fare-train.csv
https://github.com/dotnet/machinelearning/blob/master/test/data/taxi-fare-test.csv
信息:
我是ML.net的新手
代码:
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System;
using System.Threading.Tasks;
namespace TaxiFarePrediction
{
class Program
{
const string _datapath = @".\Data\taxi-fare-train.csv";
const string _testdatapath = @".\Data\taxi-fare-test.csv";
const string _modelpath = @".\Data\Model.zip";
static async Task Main(string[] args)
{
PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = await Train();
Evaluate(model);
var prediction = model.Predict(TestTrips.Trip1);
Console.WriteLine("Predicted fare: {0}", prediction.PaymentType);
Console.ReadLine();
}
static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
{
// Create learning pipeline
var pipeline = new LearningPipeline
{
// Load and transform data
new TextLoader(_datapath).CreateFrom<TaxiTrip>(separator: ','),
// Labeling
new ColumnCopier(("PaymentType", "Label")),
// Feature engineering
new CategoricalOneHotVectorizer("VendorId",
"RateCode",
"PaymentType"),
// Combine features in a single vector
new ColumnConcatenator("Features",
"VendorId",
"RateCode",
"PassengerCount",
"TripDistance",
"FareAmount"),
// Add learning algorithm
new FastTreeRegressor()
};
// Train the model
PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
// Save the model to a zip file
await model.WriteAsync(_modelpath);
return model;
}
private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
{
// Load test data
var testData = new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
// Evaluate test data
var evaluator = new RegressionEvaluator();
RegressionMetrics metrics = evaluator.Evaluate(model, testData);
// Display regression evaluation metrics
Console.WriteLine("Rms=" + metrics.Rms);
Console.WriteLine("RSquared = " + metrics.RSquared);
}
}
}
namespace TaxiFarePrediction
{
static class TestTrips
{
internal static readonly TaxiTrip Trip1 = new TaxiTrip
{
VendorId = "VTS",
RateCode = "1",
PassengerCount = 1,
TripDistance = 10.33f,
PaymentType = "",
FareAmount = 7,
//FareAmount = 0 // predict it. actual = 29.5
};
}
}
using Microsoft.ML.Runtime.Api;
namespace TaxiFarePrediction
{
public class TaxiTrip
{
[Column("0")]
public string VendorId;
[Column("1")]
public string RateCode;
[Column("2")]
public float PassengerCount;
[Column("3")]
public float TripTime;
[Column("4")]
public float TripDistance;
[Column("5")]
public string PaymentType;
[Column("6")]
public float FareAmount;
}
}
using Microsoft.ML.Runtime.Api;
namespace TaxiFarePrediction
{
public class TaxiTripFarePrediction
{
/*
[ColumnName("Score")]
public float FareAmount;
*/
[ColumnName("Score")]
public string PaymentType;
}
}
答案 0 :(得分:0)
PaymentType是分类变量。查看有关多类分类的ML.NET教程,它们应该可以使您解除封锁。