How to synchronize training and testing codebooks in Accord.Net

Date: 2016-10-20 14:47:04

Tags: c# tree random-forest accord.net

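Question: Is there a random forest example that keeps the training and test sets separate? The current example I found in the Accord.NET ML test project uses the same data for both training and testing.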
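For reference, keeping the two sets separate usually starts with a holdout split of the row indices. The sketch below is plain C# with no Accord dependency; `HoldoutSplit.Split` is a hypothetical helper, and the seeded Fisher-Yates shuffle is just one reasonable way to make the split reproducible.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of Accord.NET): shuffles row indices with a
// seeded RNG and holds out `testFraction` of them as the test set.
static class HoldoutSplit
{
    public static (int[] Train, int[] Test) Split(int rowCount, double testFraction, int seed)
    {
        if (rowCount < 0)
            throw new ArgumentOutOfRangeException(nameof(rowCount));
        if (testFraction <= 0 || testFraction >= 1)
            throw new ArgumentOutOfRangeException(nameof(testFraction));

        int[] indices = Enumerable.Range(0, rowCount).ToArray();
        var rng = new Random(seed);

        // Fisher-Yates shuffle so the split does not depend on the file's row order.
        for (int i = indices.Length - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }

        int testCount = (int)Math.Round(rowCount * testFraction);
        return (indices.Skip(testCount).ToArray(), indices.Take(testCount).ToArray());
    }
}
```

For example, `HoldoutSplit.Split(10, 0.3, 1)` yields 7 training indices and 3 disjoint test indices, which can then be used to pick rows out of the feature and label arrays.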

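Apparently, my problem is keeping the generated labels (integers) in sync between the test and training sets. I generate the training labels like this: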

int[] trainOutputs = trainCodebook.Translate("Output", trainLabels);

And the test labels similarly:

int[] testOutputs = testCodebook.Translate("Output", testLabels);

Finally I train with the train data and test with the test data:

var forest = teacher.Learn(trainVectors, trainOutputs);

int[] predicted = forest.Decide(testVectors);

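Unless the first three rows of the training and test sets contain the same labels in the same order, the integer codes differ between the two sets, which produces a very high error rate.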
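The symptom above is consistent with each codebook assigning integer codes in the order in which it first sees each label. The sketch below models that behavior with a hypothetical `FirstAppearanceCodebook` class; it is an illustration, not Accord's `Codification` API, and whether `Codification` orders codes exactly this way is an assumption.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for a label codebook; this is NOT Accord's Codification API.
// Each distinct label receives the next integer code in order of first appearance.
class FirstAppearanceCodebook
{
    private readonly Dictionary<string, int> codes = new Dictionary<string, int>();
    private readonly List<string> labelsByCode = new List<string>();

    public FirstAppearanceCodebook(IEnumerable<string> observedLabels)
    {
        foreach (string label in observedLabels)
        {
            if (!codes.ContainsKey(label))
            {
                codes[label] = labelsByCode.Count; // next unused code
                labelsByCode.Add(label);
            }
        }
    }

    // Throws KeyNotFoundException for a label this codebook never saw.
    public int Encode(string label) => codes[label];

    public string Decode(int code) => labelsByCode[code];
}
```

Built from training labels `"1", "-1", "0", "1"`, this codebook encodes `"1"` as 0; built from test labels `"0", "1", "-1"`, it encodes `"1"` as 1. The same string therefore maps to different integers in the two sets.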

I tried to create my codebook manually with the three ternary label strings:

new Codification("-1","0","1");

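Unfortunately, this produces a runtime error saying that the given key was not present in the dictionary. I'm sure there is a way to synchronize key generation across two separate codebooks. I could get the code below to work by adding three rows of my training data, which together contain all three keys, to the top of my test data. Not my preferred solution ;-)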
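One way to keep the codes in sync is to build a single codebook and use it for both sets. In Accord terms, that would mean encoding the test labels with the training codebook (`trainCodebook.Translate("Output", testLabels)`, the same call the test below already uses for the training labels), which would still fail for a test label that never occurs in the training data. As for the manual attempt: if the constructor's first argument is the column name, as in `new Codification("Output", trainLabels)`, then `new Codification("-1","0","1")` presumably defines a column named "-1" rather than "Output", which would explain the missing-key error; `new Codification("Output", "-1", "0", "1")` is likely what was intended. That reading is an assumption. The sketch below shows the shared-vocabulary idea in plain C#, using a hypothetical `SharedCodebook` class:

```csharp
using System.Collections.Generic;

// Hypothetical shared codebook (not Accord's API): built once from a fixed,
// ordered vocabulary, then used to encode BOTH training and test labels.
sealed class SharedCodebook
{
    private readonly string[] vocabulary;
    private readonly Dictionary<string, int> codes = new Dictionary<string, int>();

    public SharedCodebook(params string[] vocabulary)
    {
        this.vocabulary = (string[])vocabulary.Clone();
        for (int i = 0; i < vocabulary.Length; i++)
            codes.Add(vocabulary[i], i); // Add throws on duplicate labels
    }

    public int[] Encode(string[] labels)
    {
        var result = new int[labels.Length];
        for (int i = 0; i < labels.Length; i++)
        {
            if (!codes.TryGetValue(labels[i], out int code))
                throw new KeyNotFoundException($"Label '{labels[i]}' is not in the codebook.");
            result[i] = code;
        }
        return result;
    }

    public string Decode(int code) => vocabulary[code];
}
```

With `new SharedCodebook("-1", "0", "1")`, the label `"1"` is always code 2, whichever set it comes from and whatever row it first appears in.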

Here is the complete test I'm running:

[Test]
public void test_learn()
{
    Accord.Math.Random.Generator.Seed = 1;

    /////////// TRAINING SET ///////////
    // First, let's load the TRAINING set into an array of text that we can process
    string[][] text = Resources.train.Split(new[] { "\r\n" },
        StringSplitOptions.RemoveEmptyEntries).Apply(x => x.Split(','));

    int length = text[0].Length;
    List<int> columns = new List<int>();
    for (int i = 1; i < length; i++)
    {
        columns.Add(i);
    }
    double[][] trainVectors = text.GetColumns(columns.ToArray()).To<double[][]>();

    // The first column contains the expected ternary category (i.e. -1, 0, or 1)
    string[] trainLabels = text.GetColumn(0);
    var trainCodebook = new Codification("Output", trainLabels);
    int[] trainOutputs = trainCodebook.Translate("Output", trainLabels);

    ////////// TEST SET ////////////

    text = Resources.test.Split(new[] { "\r\n" },
        StringSplitOptions.RemoveEmptyEntries).Apply(x => x.Split(','));

    double[][] testVectors = text.GetColumns(columns.ToArray()).To<double[][]>();
    string[] testLabels = text.GetColumn(0);
    var testCodebook = new Codification("Output", testLabels);
    int[] testOutputs = testCodebook.Translate("Output", testLabels);

    var teacher = new RandomForestLearning()
    {
        NumberOfTrees = 10,
    };

    var forest = teacher.Learn(trainVectors, trainOutputs);
    int[] predicted = forest.Decide(testVectors);

    int lineNum = 1;
    foreach (int prediction in predicted)
    {
        Console.WriteLine("Prediction " + lineNum + ": " 
        + trainCodebook.Translate("Output", prediction));
        lineNum++;
    }
    // Compute the error rate on the test set, reusing the predictions from above
    double error = new ZeroOneLoss(testOutputs).Loss(predicted);

    Console.WriteLine("Error term is " + error);

    Assert.IsTrue(error < 0.20); // humble expectations ;-)
}
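The high error rate follows directly from the mismatch: a zero-one loss compares integer codes, so if the expected codes come from one codebook and the predictions from another, even a perfect model can score badly. The sketch below computes the loss in plain C#; `Metrics.ZeroOneLoss` is a hypothetical helper, not Accord's `ZeroOneLoss` class.

```csharp
using System;

static class Metrics
{
    // Fraction of positions where the predicted code differs from the expected code.
    // A plain-C# sketch of a zero-one loss, not Accord's ZeroOneLoss implementation.
    public static double ZeroOneLoss(int[] expected, int[] predicted)
    {
        if (expected.Length != predicted.Length)
            throw new ArgumentException("Arrays must have the same length.");
        if (expected.Length == 0)
            throw new ArgumentException("Arrays must not be empty.");

        int mistakes = 0;
        for (int i = 0; i < expected.Length; i++)
            if (expected[i] != predicted[i]) mistakes++;

        return (double)mistakes / expected.Length;
    }
}
```

Suppose the true test labels are `"0", "1", "-1"` and the model predicts all of them correctly. If the predictions are encoded with a training codebook that maps `1→0, -1→1, 0→2`, they become `{2, 0, 1}`. Against test-codebook targets `{0, 1, 2}`, the loss is 1.0. Against targets encoded with the same training codebook, `{2, 0, 1}`, the loss is 0.0.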

1 answer:

Answer 0 (score: 1)

Well, I figured it out. See the code below:

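OK, I think I've solved it. The problem is a faulty serialization implementation in the decision-tree code (the `RandomForest` class below): its `ParallelOptions` field cannot be serialized. Luckily we have the source code, so here is the fix: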

namespace Accord.MachineLearning.DecisionTrees
{
  using System;
  using System.Collections.Generic;
  using System.Linq;
  using System.Text;
  using System.Threading.Tasks;
  using System.Data;
  using System.Runtime.Serialization;
  using System.Runtime.Serialization.Formatters.Binary;
  using System.IO;
  using Accord.Statistics.Filters;
  using Accord.Math;
  using AForge;
  using Accord.Statistics;
  using System.Threading;


/// <summary>
///   Random Forest.
/// </summary>
/// 
/// <remarks>
/// <para>
///   Represents a random forest of <see cref="DecisionTree"/>s. For 
///   sample usage and example of learning, please see the documentation
///   page for <see cref="RandomForestLearning"/>.</para>
/// </remarks>
/// 
/// <seealso cref="DecisionTree"/>
/// <seealso cref="RandomForestLearning"/>
/// 
[Serializable]
public class RandomForest : MulticlassClassifierBase, IParallel
{
    private DecisionTree[] trees;
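    // ParallelOptions is not serializable, so exclude it from serialization
    // and recreate it after deserialization (see OnDeserializingMethod below).
    [NonSerialized]
    private ParallelOptions parallelOptions;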


    /// <summary>
    ///   Gets the trees in the random forest.
    /// </summary>
    /// 
    public DecisionTree[] Trees
    {
        get { return trees; }
    }

    /// <summary>
    ///   Gets the number of classes that can be recognized
    ///   by this random forest.
    /// </summary>
    /// 
    [Obsolete("Please use NumberOfOutputs instead.")]
    public int Classes { get { return NumberOfOutputs; } }

    /// <summary>
    ///   Gets or sets the parallelization options for this algorithm.
    /// </summary>
    ///
    public ParallelOptions ParallelOptions { get { return parallelOptions; } set { parallelOptions = value; } }

    /// <summary>
    /// Gets or sets a cancellation token that can be used
    /// to cancel the algorithm while it is running.
    /// </summary>
    /// 
    public CancellationToken Token
    {
        get { return ParallelOptions.CancellationToken; }
        set { ParallelOptions.CancellationToken = value; }
    }

    /// <summary>
    ///   Creates a new random forest.
    /// </summary>
    /// 
    /// <param name="trees">The number of trees in the forest.</param>
    /// <param name="classes">The number of classes in the classification problem.</param>
    /// 
    public RandomForest(int trees, int classes)
    {
        this.trees = new DecisionTree[trees];
        this.NumberOfOutputs = classes;
        this.ParallelOptions = new ParallelOptions();
    }

    /// <summary>
    ///   Computes the decision output for a given input vector.
    /// </summary>
    /// 
    /// <param name="data">The input vector.</param>
    /// 
    /// <returns>The forest decision for the given vector.</returns>
    /// 
    [Obsolete("Please use Decide() instead.")]
    public int Compute(double[] data)
    {
        return Decide(data);
    }


    /// <summary>
    /// Computes a class-label decision for a given <paramref name="input" />.
    /// </summary>
    /// <param name="input">The input vector that should be classified into
    /// one of the <see cref="ITransform.NumberOfOutputs" /> possible classes.</param>
    /// <returns>A class-label that best described <paramref name="input" /> according
    /// to this classifier.</returns>
    public override int Decide(double[] input)
    {
        int[] responses = new int[NumberOfOutputs];
        Parallel.For(0, trees.Length, ParallelOptions, i =>
        {
            int j = trees[i].Decide(input);
            Interlocked.Increment(ref responses[j]);
        });

        return responses.ArgMax();
    }

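    // Deserialization skips the constructor and leaves [NonSerialized] fields at
    // their default (null), so recreate the parallel options here.
    [OnDeserializing()]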
    internal void OnDeserializingMethod(StreamingContext context)
    {
        this.ParallelOptions = new ParallelOptions();
    }
}
}
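The fix above marks the field `[NonSerialized]` and recreates it in an `[OnDeserializing]` callback. A formatter-independent alternative, which is a different technique from the one in this answer, is to recreate the field lazily in the property getter. That way, any code path that skips the constructor still gets valid options. It is sketched here on a hypothetical `ForestLike` class, not on Accord's `RandomForest`:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical class for illustration only; not Accord's RandomForest.
[Serializable]
public class ForestLike
{
    // Not serialized; may be null after deserialization or uninitialized construction.
    [NonSerialized]
    private ParallelOptions parallelOptions;

    // Lazily recreate the options on first access instead of relying on a
    // deserialization callback. Setting null resets to fresh defaults on next read.
    public ParallelOptions ParallelOptions
    {
        get { return parallelOptions ?? (parallelOptions = new ParallelOptions()); }
        set { parallelOptions = value; }
    }

    public CancellationToken Token
    {
        get { return ParallelOptions.CancellationToken; }
        set { ParallelOptions.CancellationToken = value; }
    }
}
```

Deserializers create objects without running their constructors. `System.Runtime.CompilerServices.RuntimeHelpers.GetUninitializedObject` (.NET 5 or later) can simulate that in a quick check. The lazy getter is not synchronized, so two threads racing on first access could each create an options object. If cancellation matters, set `Token` before the object is shared across threads.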