自组织地图将所有训练数据放在一个坐标

时间:2015-11-20 08:40:36

标签: som

根据ai-junkie网站上的探索编写了这个SOM,它似乎将所有输入向量放在一个单独的公共坐标中,我做错了什么?我怀疑它是训练循环,因为这是该网站不清楚的一部分,任何帮助赞赏。

 static void Main(string[] args)
    {
        SOM som = new SOM(3, 5, 2, 0.1, new double[][] { new double[] {0,0,0 }, new double[] {0,0,0 } },0.00000001);
        List<string> lines = new List<string>();
        lines = System.IO.File.ReadAllLines("C:/Food.txt").ToList();
        som.training_data = new double[lines.Count][];

        for (int i = 0; i < som.training_data.Length; i++)
        {
            string[] h = lines[i].Split(',');
            som.training_data[i] = new double[h.Length-1];
            for (int j = 1; j < h.Length; j++)
            {
                som.training_data[i][j - 1] = Convert.ToSingle(h[j]);
            }
        }

        som.Train();
        for (int i = 0; i < som.training_data.Length; i++)
        {
            var best = som.Classify(som.training_data[i]);
            Console.WriteLine(lines[i].Split(',')[0]+" "+best[0]+" "+best[1]);
        }
        Console.ReadLine();

    }
}
public class MathUtil
{
    public static double Exp(double map_rad, double t, double lambda)
    {
        return map_rad * Math.Exp(-1 * (t / lambda));
    }
    public static double ExpLearn(double t, double lambda)
    {
        return Math.Exp(-1 * (t / lambda));
    }
    public static double Gaussian(int iteration, double distance, double neigh)
    {
        return Math.Exp(-1 * ((distance * distance) / 2 * neigh * neigh));
    }
}
public class Node
{
    public double[] vec;
    public double error;
    Random r = new Random();
    public Node(int dim)
    {
        vec = new double[dim];
        for (int i = 0; i < dim; i++)
        {
            vec[i] = r.NextDouble();
        }
    }
    public double Distance(double[] vec)
    {
        double sum = 0;
        for (int i = 0; i < vec.Length; i++)
        {
            sum += Math.Pow(this.vec[i] - vec[i], 2);
        }
        return Math.Sqrt(sum);
    }
    public static double Distance(double[] vec,double[] vec2)
    {
        double sum = 0;
        for (int i = 0; i < vec.Length; i++)
        {
            sum += Math.Pow(vec[i] - vec2[i], 2);
        }
        return Math.Sqrt(sum);
    }
    public double UpdateWeight(double[] inp_vec, double learn, int iteration, double neigh)
    {
        double sum = 0;
        for (int i = 0; i < inp_vec.Length; i++)
        {
            double delta = learn * MathUtil.Gaussian(iteration, Distance(inp_vec), neigh) * (vec[i] - inp_vec[i]);
            vec[i] = vec[i] + delta;
            sum += delta;
        }
        error = sum / vec.Length;
        return sum / vec.Length;
    }

    public bool InRad(double radius, int[] pos_win, int[] pos_me)
    {
        double square_sum = 0;
        for (int i = 0; i < pos_me.Length; i++)
        {
            square_sum += Math.Pow(pos_me[i] - pos_win[i], 2);
        }
        if (Math.Sqrt(square_sum) < radius)
        {
            return true;
        }
        return false;
    }
}

public class SOM
{
    public Node[,] nodes;
    public double width;
    public Random r = new Random();
    public double[][] training_data;
    public double height;
    public double Max_Error;
    public double learn_t0;

    public int[] Classify(double[] vec)
    {
        int[] best = new int[2];
        best = GetBMU(vec);
        return best;
    }
    public double GetError()
    {
        double error = 0;
        for (int i = 0; i < nodes.GetLength(0); i++)
        {
            for (int j = 0; j < nodes.GetLength(1); j++)
            {
                error += nodes[i, j].error;
            }
        }
        return error;
    }
    public double neighboorhood_radius(int iteration)
    {
        return MathUtil.Exp(Map_Radius_t0, iteration, lambda(iteration));
    }
    public double LearnFac(int iteration)
    {
        return learn_t0 * MathUtil.ExpLearn(iteration, lambda(iteration));
    }
    public double lambda(int iteration)
    {
        return iteration / Math.Log(Map_Radius_t0);
    }
    public double Map_Radius_t0
    {
        get
        {
            return Math.Max(width, height) / 2;
        }
    }

    public int[] GetBMU(double[] ran_tr_vec)
    {
        int[] best = new int[2];
        double smallest = double.MaxValue;
        for (int i = 0; i < nodes.GetLength(0); i++)
        {
            for (int j = 0; j < nodes.GetLength(1); j++)
            {
                if (Node.Distance(nodes[i, j].vec, ran_tr_vec) < smallest)
                {
                    best[0] = i;
                    best[1] = j;
                }
            }
        }
        return best;
    }
    public void Train()
    {
        for (int u = 0; u < 10000; u++)//while(GetError()>Max_Error)//
        {
            for (int ind = 0; ind < training_data.Length; ind++)
            {
                int iter = u;
                var inp_vec = training_data[ind];//r.Next(0, training_data.Length)];
                int[] best = GetBMU(inp_vec);
                #region Update_Weights
                for (int i = 0; i < nodes.GetLength(0); i++)
                {
                    for (int j = 0; j < nodes.GetLength(1); j++)
                    {
                        if (nodes[i, j].InRad(neighboorhood_radius(iter), best, new int[] { i, j }))
                        {
                            nodes[i, j].UpdateWeight(inp_vec, LearnFac(iter), iter, neighboorhood_radius(iter));
                        }
                    }
                }
                #endregion
            }
            Console.WriteLine(GetError());
        }
    }
    public SOM(int dim, int len, int bredth, double learn, double[][] tr_data, double Max_Error)
    {
        #region ini
        training_data = tr_data;
        learn_t0 = learn;
        width = bredth;
        height = len;
        this.Max_Error = Max_Error;
        nodes = new Node[len, bredth];
        for (int i = 0; i < nodes.GetLength(0); i++)
        {
            for (int j = 0; j < nodes.GetLength(1); j++)
            {
                nodes[i, j] = new Node(dim);
            }
        }
        #endregion

    }
}

}

0 个答案:

没有答案