使用Encog

时间:2016-08-13 21:54:23

标签: c# neural-network encog

我想使用 Encog，在原帖链接给出的训练数据上训练一个神经网络。数据有 17 个输入特征（2 个数值型，15 个类别型）和 2 个输出特征（均为类别型）。

我想创建一个基本的前馈网络来解决这个问题，但到目前为止我的网络训练都未能收敛。我的网络设计如下：

  • 输入层：57 个节点
    • 2 个节点：A–B 列中的分数
    • 3 个节点：C–E 列中的先前叫牌（-1 表示无叫牌，0 表示 Pass，2 表示叫 2，3 表示叫 3）
    • 52 个节点：G–Q 列中的 6 张牌（编码为 “six-hot” 向量，即 52 维向量中恰有 6 个位置为 1）
  • 隐藏层：104 个节点（仅是基于 2 * 57 的粗略猜测）
  • 输出层：13 个节点（3 种非 Pass 叫牌 × 4 种花色 + 1 个 Pass）

我使用 tanh 作为激活函数，并启用了偏置节点。创建此网络的 C# 调用是：

Encog.Util.Simple.EncogUtility.SimpleFeedForward(nInputs, nHidden, 0, nOutputs, true)

我是神经网络的新手，所以真的不确定该如何处理这个问题。到目前为止我一直在反复试错，但希望能有更好的方法。谢谢。

1 个答案:

答案 0（得分：0）

我用下面的代码读入了您的数据，训练时可以输出训练误差（error），看看这能否解决您的问题。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data;
using System.Data.OleDb;
using System.IO;


using Encog.Neural.Networks;
using Encog.Neural.Networks.Layers;
using Encog.Engine.Network.Activation;
using Encog.ML.Data;
using Encog.Neural.Networks.Training.Propagation.Resilient;
using Encog.ML.Train;
using Encog.ML.Data.Basic;
using Encog;
using System.Data;
using System.Data.OleDb;
using System.IO;


namespace encog_sample_csharp
{
    internal class Program
    {
        /// <summary>
        /// Path of the CSV file that holds the bidding training data.
        /// </summary>
        const string FILENAME = @"c:\temp\BidTraining.csv";

        /// <summary>Training stops as soon as the RPROP error falls below this value.</summary>
        const double TargetError = 0.01;

        /// <summary>
        /// Hard cap on training iterations, so the loop still terminates when the
        /// network cannot reach <see cref="TargetError"/>.
        /// </summary>
        const int MaxEpochs = 5000;

        /// <summary>Number of input values per row produced by <see cref="Dictionarys.GetDataSet"/>.</summary>
        const int InputNeurons = 17;

        /// <summary>Hidden-layer size; 2 x inputs is only a starting guess, not a tuned value.</summary>
        const int HiddenNeurons = 2 * InputNeurons;

        /// <summary>Number of ideal values per row (bid and suit).</summary>
        const int OutputNeurons = 2;

        static DataTable dt = null;

        /// <summary>
        /// Loads the CSV data, trains a 17-34-2 feed-forward network with RPROP and
        /// prints the network's output next to the ideal values for every row.
        /// </summary>
        private static void Main(string[] args)
        {
            CSVReader reader = new CSVReader();
            DataSet ds = reader.ReadCSVFile(FILENAME, true);

            // ReadCSVFile reports failures on the console and returns a DataSet without
            // "Table1"; in that case there is nothing to train on.
            dt = ds.Tables["Table1"];
            if (dt == null || dt.Rows.Count == 0)
            {
                Console.WriteLine(@"No training data loaded from " + FILENAME);
                return;
            }

            // create a neural network, without using a factory
            var network = new BasicNetwork();
            // input layer: no activation; its bias neuron feeds the hidden layer
            network.AddLayer(new BasicLayer(null, true, InputNeurons));
            // hidden layer: tanh, as in the network design from the question
            network.AddLayer(new BasicLayer(new ActivationTANH(), true, HiddenNeurons));
            // output layer: linear, because the encoded targets (bid -1..6, suit -1..4)
            // lie outside the (0, 1) range that a sigmoid output can produce
            network.AddLayer(new BasicLayer(new ActivationLinear(), false, OutputNeurons));
            network.Structure.FinalizeStructure();
            network.Reset();

            // create training data
            Dictionarys dict = new Dictionarys();
            IMLDataSet dataSet = dict.GetDataSet(dt);

            // train the neural network
            IMLTrain train = new ResilientPropagation(network, dataSet);

            int epoch = 1;
            do
            {
                train.Iteration();
                Console.WriteLine(@"Epoch #" + epoch + @" Error:" + train.Error);
                epoch++;
            } while (train.Error > TargetError && epoch <= MaxEpochs);

            train.FinishTraining();

            if (train.Error > TargetError)
            {
                Console.WriteLine(@"Stopped after " + MaxEpochs + @" epochs without reaching error " + TargetError);
            }

            // test the neural network: report both outputs (bid, suit) against their ideals
            Console.WriteLine(@"Neural Network Results:");
            foreach (IMLDataPair pair in dataSet)
            {
                IMLData output = network.Compute(pair.Input);
                Console.WriteLine(pair.Input[0] + @"," + pair.Input[1]
                                  + @", actual=" + output[0] + @"/" + output[1]
                                  + @",ideal=" + pair.Ideal[0] + @"/" + pair.Ideal[1]);
            }

            EncogFramework.Instance.Shutdown();
        }
    }
    /// <summary>
    /// Converts the rows of the bidding table into an Encog training set, mapping the
    /// categorical labels (bids, card ranks, suits) to numbers via the lookup tables below.
    /// </summary>
    public class Dictionarys
    {
        /// <summary>Leading numeric score columns (indices 0-1).</summary>
        private const int ScoreColumns = 2;

        /// <summary>Previous-bid columns that follow the scores (indices 2-4).</summary>
        private const int PreviousBidColumns = 3;

        /// <summary>Number of cards; each card is a rank column followed by a suit column.</summary>
        private const int CardCount = 6;

        /// <summary>Total number of input columns (17); the two output columns follow them.</summary>
        private const int InputColumns = ScoreColumns + PreviousBidColumns + 2 * CardCount;

        /// <summary>Input columns plus the bid and suit output columns.</summary>
        private const int ExpectedColumns = InputColumns + 2;

        /// <summary>Input vectors built by the last call to <see cref="GetDataSet"/>, one per row.</summary>
        public double[][] inputNeurons;

        /// <summary>Ideal (bid, suit) vectors built by the last call to <see cref="GetDataSet"/>.</summary>
        public double[][] outputNeurons;

        public static readonly Dictionary<string, double> bid = new Dictionary<string, double>(){
             {"None", -1.0},
             {"Pass", 0.0},
             {"One", 1.0},
             {"Two", 2.0},
             {"Three", 3.0},
             {"Four", 4.0},
             {"Five", 5.0},
             {"Six", 6.0}
        };
        public static readonly Dictionary<string, double> rank = new Dictionary<string, double>() {
            {"Ace", 1.0},
            {"Two", 2.0},
            {"Three", 3.0},
            {"Four", 4.0},
            {"Five", 5.0},
            {"Six", 6.0},
            {"Seven", 7.0},
            {"Eight", 8.0},
            {"Nine", 9.0},
            {"Ten", 10.0},
            {"Jack", 11.0},
            {"Queen", 12.0},
            {"King", 13.0}
        };
        public static readonly Dictionary<string, double> suit = new Dictionary<string, double>() {
            {"None",-1.0},
            {"SuitA",1.0},
            {"SuitB",2.0},
            {"SuitC",3.0},
            {"SuitD",4.0}
        };

        /// <summary>
        /// Builds the training set from <paramref name="dt"/> and stores the raw vectors in
        /// <see cref="inputNeurons"/> and <see cref="outputNeurons"/>.
        /// </summary>
        /// <param name="dt">Table whose first 17 columns are inputs and columns 17-18 are the ideal bid and suit.</param>
        /// <returns>A data set pairing each 17-value input with its 2-value ideal output.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="dt"/> is null.</exception>
        /// <exception cref="ArgumentException">The table has fewer than 19 columns.</exception>
        /// <exception cref="KeyNotFoundException">A cell holds a label missing from the lookup tables.</exception>
        public IMLDataSet GetDataSet(DataTable dt)
        {
            if (dt == null)
            {
                throw new ArgumentNullException("dt");
            }
            if (dt.Columns.Count < ExpectedColumns)
            {
                throw new ArgumentException(string.Format(
                    "Expected at least {0} columns but the table has {1}.",
                    ExpectedColumns, dt.Columns.Count), "dt");
            }

            inputNeurons = dt.AsEnumerable().Select(row => BuildInput(row)).ToArray();

            outputNeurons = dt.AsEnumerable().Select(row => new[] {
                Lookup(bid, row, InputColumns),
                Lookup(suit, row, InputColumns + 1)
            }).ToArray();

            IMLDataSet trainingSet = new BasicMLDataSet(inputNeurons, outputNeurons);
            return trainingSet;
        }

        /// <summary>Encodes the 17 input columns of one row as numbers.</summary>
        private static double[] BuildInput(DataRow row)
        {
            var input = new double[InputColumns];
            int column = 0;

            // scores: accept whatever numeric type the CSV driver inferred for the column
            for (; column < ScoreColumns; column++)
            {
                input[column] = Convert.ToDouble(row[column], System.Globalization.CultureInfo.InvariantCulture);
            }

            // previous bids
            for (; column < ScoreColumns + PreviousBidColumns; column++)
            {
                input[column] = Lookup(bid, row, column);
            }

            // cards: one (rank, suit) column pair per card
            for (; column < InputColumns; column += 2)
            {
                input[column] = Lookup(rank, row, column);
                input[column + 1] = Lookup(suit, row, column + 1);
            }

            return input;
        }

        /// <summary>
        /// Maps the string label in <paramref name="column"/> of <paramref name="row"/> through
        /// <paramref name="table"/>, naming the offending column and value when it is unknown.
        /// </summary>
        private static double Lookup(Dictionary<string, double> table, DataRow row, int column)
        {
            string label = row.Field<string>(column);
            double value;
            if (label == null || !table.TryGetValue(label, out value))
            {
                throw new KeyNotFoundException(string.Format(
                    "Unrecognised value '{0}' in column {1} ({2}).",
                    label, column, row.Table.Columns[column].ColumnName));
            }
            return value;
        }
    }
    /// <summary>
    /// Loads a delimited text file into a <see cref="DataSet"/> through the Jet OLE DB text driver.
    /// </summary>
    public class CSVReader
    {
        /// <summary>
        /// Reads <paramref name="fullPath"/> into a table named "Table1", replacing spaces in
        /// column names with underscores.
        /// </summary>
        /// <param name="fullPath">Path of the CSV file (absolute or relative).</param>
        /// <param name="headerRow">True when the first line of the file holds column names.</param>
        /// <returns>
        /// A DataSet containing "Table1" on success. When the file is missing or cannot be read,
        /// the reason is written to the console and a DataSet without "Table1" is returned.
        /// </returns>
        public DataSet ReadCSVFile(string fullPath, bool headerRow)
        {
            DataSet ds = new DataSet();

            // Checked up front so a missing file is reported as such instead of failing later.
            if (!File.Exists(fullPath))
            {
                Console.WriteLine("CSV file not found: " + fullPath);
                return ds;
            }

            try
            {
                // The text driver takes the folder as the data source and the file name as the table.
                string absolutePath = Path.GetFullPath(fullPath);
                string path = Path.GetDirectoryName(absolutePath);
                string filename = Path.GetFileName(absolutePath);

                // NOTE(review): Microsoft.Jet.OLEDB.4.0 is generally only available to 32-bit
                // processes; a 64-bit build would need the ACE provider — confirm the target platform.
                string ConStr = string.Format("Provider=Microsoft.Jet.OLEDB.4.0;Data Source={0};Extended Properties=\"Text;HDR={1};FMT=Delimited\"", path, headerRow ? "Yes" : "No");
                // Brackets allow file names containing spaces or dashes to be used as the table name.
                string SQL = string.Format("SELECT * FROM [{0}]", filename);

                using (OleDbDataAdapter adapter = new OleDbDataAdapter(SQL, ConStr))
                {
                    adapter.Fill(ds, "TextFile");
                }
                ds.Tables[0].TableName = "Table1";

                foreach (DataColumn col in ds.Tables["Table1"].Columns)
                {
                    col.ColumnName = col.ColumnName.Replace(" ", "_");
                }
            }
            catch (Exception ex)
            {
                // Best effort: report the problem and return whatever was loaded so far.
                Console.WriteLine("Could not read " + fullPath + ": " + ex.Message);
            }
            return ds;
        }
    }
}