How to use a string as the prediction column

Posted: 2019-10-24 12:08:21

Tags: ml.net

Goal:
Predict PaymentType. The target (the prediction column) is a string value.

Problem:
I get the following error: "ArgumentOutOfRangeException: Training label column 'Label' type is not valid for regression: Text. Type must be R4 or R8. Parameter name: data"

Am I missing part of the source code?

Thanks!

Data:
https://github.com/dotnet/machinelearning/blob/master/test/data/taxi-fare-train.csv

https://github.com/dotnet/machinelearning/blob/master/test/data/taxi-fare-test.csv

Info:
I am new to ML.NET.

Code:

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System;
using System.Threading.Tasks;

namespace TaxiFarePrediction
{
    class Program
    {
        const string _datapath = @".\Data\taxi-fare-train.csv";
        const string _testdatapath = @".\Data\taxi-fare-test.csv";
        const string _modelpath = @".\Data\Model.zip";

        static async Task Main(string[] args)
        {
            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = await Train();
            Evaluate(model);

            var prediction = model.Predict(TestTrips.Trip1);
            Console.WriteLine("Predicted fare: {0}", prediction.PaymentType);

            Console.ReadLine();
        }

        static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
        {
            // Create learning pipeline
            var pipeline = new LearningPipeline
            {
                // Load and transform data
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(separator: ','),

                 // Labeling
                new ColumnCopier(("PaymentType", "Label")),

                // Feature engineering
                new CategoricalOneHotVectorizer("VendorId",
                    "RateCode",
                    "PaymentType"),

                 // Combine features in a single vector
                new ColumnConcatenator("Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "FareAmount"),

                // Add learning algorithm
                new FastTreeRegressor()
            };

            // Train the model
            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();

            // Save the model to a zip file
            await model.WriteAsync(_modelpath);

            return model;
        }

        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            // Load test data
            var testData = new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');

            // Evaluate test data
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);

            // Display regression evaluation metrics
            Console.WriteLine("Rms=" + metrics.Rms);
            Console.WriteLine("RSquared = " + metrics.RSquared);
        }
    }
}

namespace TaxiFarePrediction
{
    static class TestTrips
    {
        internal static readonly TaxiTrip Trip1 = new TaxiTrip
        {
            VendorId = "VTS",
            RateCode = "1",
            PassengerCount = 1,
            TripDistance = 10.33f,
            PaymentType = "", 
            FareAmount = 7,
            //FareAmount = 0 // predict it. actual = 29.5
        };
    }
}

using Microsoft.ML.Runtime.Api;

namespace TaxiFarePrediction
{
    public class TaxiTrip
    {
        [Column("0")]
        public string VendorId;

        [Column("1")]
        public string RateCode;

        [Column("2")]
        public float PassengerCount;

        [Column("3")]
        public float TripTime;

        [Column("4")]
        public float TripDistance;

        [Column("5")]
        public string PaymentType;

        [Column("6")]
        public float FareAmount;
    }

}

using Microsoft.ML.Runtime.Api;

namespace TaxiFarePrediction
{
    public class TaxiTripFarePrediction
    {

        /*
        [ColumnName("Score")]
        public float FareAmount;

        */
        [ColumnName("Score")]
        public string PaymentType;
    }
}

1 answer:

Answer 0 (score: 0)

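PaymentType is a categorical variable, so this is a multiclass classification problem, not a regression problem. FastTreeRegressor only accepts a numeric label (R4 or R8, i.e. float or double), which is why training fails on the Text column. Look at the ML.NET tutorials on multiclass classification; they should get you unblocked.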
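For reference, here is a minimal sketch of the question's program rewritten as multiclass classification. It assumes the same pre-1.0 LearningPipeline API (Microsoft.ML 0.x) that the question uses and has not been checked against a specific 0.x release. The output class TaxiTripPaymentPrediction and its field PredictedPaymentType are hypothetical names. Compared with the question, it maps the string label to a key with Dictionarizer, replaces FastTreeRegressor with StochasticDualCoordinateAscentClassifier, converts the predicted key back to the original string and reads it from PredictedLabel (for a multiclass model, Score is a vector of per-class scores, not a string), skips the CSV header row, and evaluates on the test file with ClassificationEvaluator.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace TaxiFarePrediction
{
    // Same schema as the question's TaxiTrip (column order of taxi-fare-*.csv).
    public class TaxiTrip
    {
        [Column("0")] public string VendorId;
        [Column("1")] public string RateCode;
        [Column("2")] public float PassengerCount;
        [Column("3")] public float TripTime;
        [Column("4")] public float TripDistance;
        [Column("5")] public string PaymentType;
        [Column("6")] public float FareAmount;
    }

    // Hypothetical output class: read the predicted class from "PredictedLabel",
    // not "Score" (Score holds one score per class).
    public class TaxiTripPaymentPrediction
    {
        [ColumnName("PredictedLabel")]
        public string PredictedPaymentType;
    }

    class Program
    {
        const string _datapath = @".\Data\taxi-fare-train.csv";
        const string _testdatapath = @".\Data\taxi-fare-test.csv";
        const string _modelpath = @".\Data\Model.zip";

        static async Task Main(string[] args)
        {
            var model = await Train();
            Evaluate(model);

            // Same values as the question's TestTrips.Trip1, without PaymentType.
            var trip = new TaxiTrip
            {
                VendorId = "VTS", RateCode = "1", PassengerCount = 1,
                TripDistance = 10.33f, FareAmount = 7
            };
            Console.WriteLine("Predicted payment type: {0}", model.Predict(trip).PredictedPaymentType);
        }

        static async Task<PredictionModel<TaxiTrip, TaxiTripPaymentPrediction>> Train()
        {
            var pipeline = new LearningPipeline
            {
                // The CSV files start with a header row, so skip it.
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),

                // Copy the string target into "Label", then map each distinct string to a key (class index).
                new ColumnCopier(("PaymentType", "Label")),
                new Dictionarizer("Label"),

                // One-hot encode only the categorical features; PaymentType is the target, not a feature.
                new CategoricalOneHotVectorizer("VendorId", "RateCode"),
                new ColumnConcatenator("Features",
                    "VendorId", "RateCode", "PassengerCount", "TripDistance", "FareAmount"),

                // A multiclass trainer instead of the regression trainer FastTreeRegressor.
                new StochasticDualCoordinateAscentClassifier(),

                // Convert the predicted key back to the original string value.
                new PredictedLabelColumnOriginalValueConverter { PredictedLabelColumn = "PredictedLabel" }
            };

            var model = pipeline.Train<TaxiTrip, TaxiTripPaymentPrediction>();
            await model.WriteAsync(_modelpath);
            return model;
        }

        static void Evaluate(PredictionModel<TaxiTrip, TaxiTripPaymentPrediction> model)
        {
            // Evaluate on the test file (the question loads the training file here).
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            ClassificationMetrics metrics = new ClassificationEvaluator().Evaluate(model, testData);
            Console.WriteLine("Micro accuracy = " + metrics.AccuracyMicro);
            Console.WriteLine("Macro accuracy = " + metrics.AccuracyMacro);
        }
    }
}
```

In ML.NET 1.0 and later, LearningPipeline no longer exists; the same idea is expressed with mlContext.Transforms.Conversion.MapValueToKey for the label, a trainer from mlContext.MulticlassClassification.Trainers, and mlContext.Transforms.Conversion.MapKeyToValue to turn the predicted key back into a string.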