(C ++)K-Means聚类麻烦

时间:2015-11-23 00:26:05

标签: c++ algorithm machine-learning computer-vision k-means

我在使用K-means聚类算法时遇到了一些麻烦。输入文件如下所示:

4

60 60

23 45

25 11

30 11

...

...

...

在60x60图像网格中共有4个聚类。该算法似乎有效,但当它重新开始重新计算质心并更改标签时,网格中的标签会慢慢开始变为1。经过大约5次迭代后,所有标签都变为1.我已经根据我的知识反复审查了代码,我无法弄清楚为什么所有标签都只变为1。任何帮助表示赞赏!

#include <iostream>
#include <fstream>
#include <cmath>
using namespace std;

class Node {
public:
    int x;
    int y;
    int label;
    double distace;
    Node* next;

    void printNode() {
        cout << "X:\t" << x << endl;
        cout << "Y:\t" << y << endl;
        cout << "Label:\t" << label << endl;
    }

    Node(int i, int j) {
        x = i;
        y = j;
        next = nullptr;
    }
};

class LinkedList {
public:
    Node* head;
    int length;
    Node* scanner;

    LinkedList() {
        Node* n = new Node(-999, -999);
        head = n;
        scanner = head;
        length = 0;
    }

    void insert(Node* n) {
        if (!head->next) {
            head->next = n;
            length++;
            return;
        }
        n->next = head->next;
        head->next = n;
        length++;
    }

    void deleteNode(Node* n) {
        Node* prev = head;
        Node* current = head->next;
        while (current) {
            if (n->x == current->x && 
                n->y == current->y &&
                n->label == current->label) {
                    prev->next = current->next;
                    return;
                }
                prev = prev->next;
                current = current->next;
        }
        length--;
    }

    void printList() {
        Node* current = head->next;
        while (current) {
            current->printNode();
            cout << endl;
            current = current->next;
        }
    }

    Node* scan() {
        scanner = scanner->next;
        if (!scanner) return 0;
        return scanner;
    }

    void resetScanner() {
        scanner = head;
    }

    void changeLabelTo(int x, int y, int newLabel) {
        Node* current = head->next;
        while (current) {
            if (current->x == x &&
                current->y == y) {
                    current->label = newLabel;
                }

            current = current->next;
        }
    }
};

class KMeans {
public:
    struct xycoord {
      int x;
      int y;
    };

    int k;
    xycoord* kcentroids;
    LinkedList ll;
    int row;
    int col;
    int** image;
    int tracker;

    int getLabel() {
        int x = tracker++;
        if (tracker > k) { tracker = 1; }
        return x;
    }

    KMeans(int clusters, int r, int c) { 
        k = clusters;
        tracker = 1;
        kcentroids = new xycoord[k+1];
        row = r;
        col = c;
        image = new int*[row];
        for (int i = 0; i < row; i++)
            image[i] = new int[col];

    }

    ~KMeans() {
        delete [] kcentroids;
        for (int i = 0; i < row; i++)
            delete [] image[i];
        delete [] image;

        cout << "Called!" << endl;
    }

    void displayImage() {
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                if (image[i][j] == 0) cout << " ";
                else { cout << image[i][j]; }
            }
            cout << endl;
        }
    }

    void imageOutput() {
        for (int i = 0; i < ll.length; i++) {
            Node* n = ll.scan();
            if (n) {
                image[n->x][n->y] = n->label;
            }
        }
        displayImage();
    }

    void computeCentroids() {
        for (int i = 1; i <= k; i++) {
            kcentroids[i].x = 0;
            kcentroids[i].y = 0;
        }

        int* count = new int[k+1];
        for (int i  = 0; i < ll.length; i++) {
            Node* n = ll.scan();
            if (n) {
                kcentroids[n->label].x += n->x;
                kcentroids[n->label].y += n->y;
                count[n->label]++;
            }
        }

        cout << endl;

        for (int i = 1; i <= k; i++) {
            kcentroids[i].x = kcentroids[i].x / count[i];
            kcentroids[i].y = kcentroids[i].y / count[i];

            /*
             * i - label
             * 6 - 4
             * 7 - 3
             * 8 - 2
             * 9 - 1
             */
            //image[kcentroids[i].x][kcentroids[i].y] = 10-i;
        }


        delete [] count;
    }

    void computeDistanceAndSetLabels() {
        for (int i = 0; i < ll.length; i++) {
            int minLabel = 0;
            double min = 99999.0;
            Node* n = ll.scan();
            for (int j = 1; j <= k; j++) {
                double m = 0.0;
                // distance formula
                m = sqrt(pow(n->x-kcentroids[j].x, 2) + pow(n->y-kcentroids[j].y,2));
                if (m < min) {
                    min = m;
                    minLabel = j;
                }
            }
            cout << i << " " << n->x << " " << n->y << " " << n->label << " ";
            n->label = minLabel;
            cout << minLabel << " " << n->label <<endl;
        }
    }

    void startClustering() {
        // more than 2 starts showing 1 take over all labels, 
        // this is to be changed to something better however
        for (int i = 0; i < 4; i++) {
            ll.resetScanner();
            imageOutput();
            ll.resetScanner();
            computeCentroids();
            ll.resetScanner();
            computeDistanceAndSetLabels();
            ll.resetScanner();
            imageOutput();
       }
    }
};

Node* createNode(int x, int y, int k) {
    Node* n = new Node(x, y);
    n->label = k;
    return n;
}

int main(int argc, char* argv[]) {
    if (argc < 4) {
        cout << "Please start program with: program in.txt out1.txt out2.txt" << endl;
        return -1;
    }

    ifstream in(argv[1]);
    if (!in) {
        cout << "File: " << argv[1] << " could not be read" << endl;
        return -2;
    }

    int k, rows, cols;
    in >> k;
    in >> rows;
    in >> cols;
    cout << k << " " << rows << " " << cols << endl;

    KMeans km(k, rows, cols);
    km.displayImage();

    LinkedList ll;
    int num;
    while (in >> num) {
        int num2;
        in >> num2;
        ll.insert(createNode(num, num2, km.getLabel()));
    }
    km.ll = ll;
    km.ll.printList();
    km.startClustering();
    km.displayImage();

    return 0;
}

1 个答案:

答案 0 :(得分:1)

您未在computeCentroids()中将计数初始化为零,从而导致未定义的行为。将它们与您的质心一起归零:

int* count = new int[k+1];
for (int i = 1; i <= k; i++) {
    kcentroids[i].x = 0;
    kcentroids[i].y = 0;
    count[i] = 0;
}

当您按照以下方式新建数组时,也可以将零数组归零:

int* count = new int[k+1]();