ANN:学习矢量量化不起作用

时间:2015-12-22 12:21:02

标签: c algorithm machine-learning neural-network unsupervised-learning

我希望有人可以帮助我: 我试图实现神经网络来查找数据集群,这些数据集呈现为2D集群。我试图遵循wikipedia中描述的标准算法:我寻找每个数据点的最小距离,并更新该神经元朝向数据点的权重。当总距离足够小时,我就停止这样做了。

我的结果是找到大多数聚类,但在视图上是错误的,虽然它计算了一个永久距离但它不再收敛。我的错误在哪里?

typedef struct{
    double x;
    double y;
}Data;

typedef struct{
    double x;
    double y;
}Neuron;

typedef struct{
    size_t numNeurons;
    Neuron* neurons;
}Network;

int main(void){
    srand(time(NULL));

    Data trainingData[1000];
    size_t sizeTrainingData = 0;
    size_t sizeClasses = 0;
    Network network;

    getData(trainingData, &sizeTrainingData, &sizeClasses);

    initializeNetwork(&network, sizeClasses);
    normalizeData(trainingData, sizeTrainingData);
    train(&network, trainingData, sizeTrainingData);

    return 0;
}

void train(Network* network, Data trainingData[], size_t sizeTrainingData){
    for(int epoch=0; epoch<TRAINING_EPOCHS; ++epoch){
        double learningRate = getLearningRate(epoch);
        double totalDistance = 0;
        for(int i=0; i<sizeTrainingData; ++i){
            Data currentData = trainingData[i];
            int winningNeuron = 0;
            totalDistance += findWinningNeuron(network, currentData, &winningNeuron);
            //update weight
            network->neurons[i].x += learningRate * (currentData.x - network->neurons[i].x);
            network->neurons[i].y += learningRate * (currentData.y - network->neurons[i].y);
        }
        if(totalDistance<MIN_TOTAL_DISTANCE) break;
    }
}

double getLearningRate(int epoch){
    return LEARNING_RATE * exp(-log(LEARNING_RATE/LEARNING_RATE_MIN_VALUE)*((double)epoch/TRAINING_EPOCHS));
}

double findWinningNeuron(Network* network, Data data, int* winningNeuron){
    double smallestDistance = 9999;
    for(unsigned int currentNeuronIndex=0; currentNeuronIndex<network->numNeurons; ++currentNeuronIndex){
        Neuron neuron = network->neurons[currentNeuronIndex];
        double distance = sqrt(pow(data.x-neuron.x,2)+pow(data.y-neuron.y,2));
        if(distance<smallestDistance){
            smallestDistance = distance;
            *winningNeuron = currentNeuronIndex;
        }
    }
    return smallestDistance;
}

initializeNetwork(...)启动所有具有-1和1范围内随机权重的神经元。 normalizeData(...)以某种方式规范化,因此最大值为1.

示例: 如果我向网络提供大约50个(规范化的)数据点,这些数据点分为3个群集,剩余的totaldistance将保持在 7.3 。当我检查神经元的位置时,它应该代表群集的中心,两个是完美的,一个位于群集的边界。算法不应该更多地移动到中心吗?我重复了几次算法,输出总是相似的(完全相同的错误的点)

1 个答案:

答案 0 :(得分:1)

你的代码看起来不像LVQ,特别是你没有使用过获胜的神经元,而你只应移动这个

void train(Network* network, Data trainingData[], size_t sizeTrainingData){
    for(int epoch=0; epoch<TRAINING_EPOCHS; ++epoch){
        double learningRate = getLearningRate(epoch);
        double totalDistance = 0;
        for(int i=0; i<sizeTrainingData; ++i){
            Data currentData = trainingData[i];
            int winningNeuron = 0;
            totalDistance += findWinningNeuron(network, currentData, &winningNeuron);
            //update weight
            network->neurons[i].x += learningRate * (currentData.x - network->neurons[i].x);
            network->neurons[i].y += learningRate * (currentData.y - network->neurons[i].y);
        }
        if(totalDistance<MIN_TOTAL_DISTANCE) break;
    }
}

您要移动的神经元位于winningNeuron但您更新i神经元i实际迭代训练样本,我很惊讶您不会跌倒关闭你的记忆(网络 - &gt;神经元应该小于sizeTrainingData)。我想你的意思是

void train(Network* network, Data trainingData[], size_t sizeTrainingData){
    for(int epoch=0; epoch<TRAINING_EPOCHS; ++epoch){
        double learningRate = getLearningRate(epoch);
        double totalDistance = 0;
        for(int i=0; i<sizeTrainingData; ++i){
            Data currentData = trainingData[i];
            int winningNeuron = 0;
            totalDistance += findWinningNeuron(network, currentData, &winningNeuron);
            //update weight
            network->neurons[winningNeuron].x += learningRate * (currentData.x - network->neurons[winningNeuron].x);
            network->neurons[winningNeuron].y += learningRate * (currentData.y - network->neurons[winningNeuron].y);
        }
        if(totalDistance<MIN_TOTAL_DISTANCE) break;
    }
}