优化apriori算法代码

时间:2017-02-14 14:23:09

标签: algorithm optimization data-mining

我正在为数据挖掘编写apriori算法的代码我的代码需要长达60秒才能获得一个非常小的数据集,这个数据集只需2秒即可通过互联网获得的其他代码解决但我不知道我在哪里做错了,有人可以告诉我为什么其他代码快速超过我的。

我的代码:

import java.util.*;
import java.io.*;

public class Apriori_p {

double support;
ArrayList<String> trans;
Map<String, Integer> map;
  long start;

void print(ArrayList<String> temp) {

    for (int i = 0; i < temp.size(); i++) {
        System.out.println(temp.get(i));
    }
    System.out.println("Count :" + temp.size());
}

void run() throws FileNotFoundException {

      start = System.currentTimeMillis();
    trans = new ArrayList<>();
    ArrayList<String> temp = new ArrayList<>();
    map = new HashMap<>();
    Scanner sc = new Scanner(System.in);
    System.out.println("Enter support %");
    support = sc.nextDouble();

    System.out.println("Enter file name");
    String file = sc.next();

    sc = new Scanner(new File(file));
    int lines = 0;
    while (sc.hasNextLine()) {

        String s = sc.nextLine();
        if (s.matches("\\s*")) {
            continue;
        }
        lines++;

        String[] spl = s.split("\\s+");
        ArrayList<Integer> elem = new ArrayList<>();

        for (int i = 0; i < spl.length; i++) {

            String cand;
            int n = Integer.parseInt(spl[i]);

            cand = spl[i].trim();

            if (!elem.contains(n)) {
                elem.add(n);
            }

            if (map.containsKey(cand)) {
                int count = map.get(cand);
                map.put(cand, count + 1);
            } else {
                map.put(cand, 1);
            }
        }

        Collections.sort(elem);
        String con = " ";
        for (int i = 0; i < elem.size(); i++) {
            con = con + elem.get(i) + " ";
            String s1 = String.valueOf(elem.get(i)).trim();
            if(!temp.contains(s1))
                temp.add(s1);
        }

        trans.add(con);
    }

    support = (support * lines) / 100;
    System.out.println(System.currentTimeMillis() - start);
    apriori(temp, 1);

}

public static void main(String[] args) throws FileNotFoundException {

    new Apriori_p().run();
}

public void apriori(ArrayList<String> temp, int m) {

    Set<String> diff = null;

    if (m == 1) {
        diff = new HashSet<>();
    }

    for (int i = 0; i < temp.size(); i++) {
        if (map.get(temp.get(i)) < support) {

            if (m == 1) {
                diff.add(temp.get(i));
            }

            temp.remove(i);
            i--;
        }
    }

    for (int i = 0; i < trans.size() && m == 1; i++) {

        for (String j : diff) {
            String rep = " " + j + " ";
            trans.get(i).replace(rep, " ");
        }

    }

    if (temp.size() == 0) {
        return;
    }

    System.out.println("Size " + m + " :");
    print(temp);

    ArrayList<String> ntemp = new ArrayList<>();

    int n = temp.size();

    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {

            StringTokenizer st1 = new StringTokenizer(temp.get(i), " ");
            StringTokenizer st2 = new StringTokenizer(temp.get(j), " ");

            String str1 = "", str2 = "";

            for (int s = 0; s < m - 2; s++) {
                str1 = str1 + " " + st1.nextToken();
                str2 = str2 + " " + st2.nextToken();
            }

            if (str2.compareToIgnoreCase(str1) == 0) {

                int s1 = Integer.parseInt(st1.nextToken()), s2 = Integer.parseInt(st2.nextToken());
                String s3;    

                if (s1 <= s2) {
                   s3 = (str1 + " " + s1 + " " + s2).trim();
                } else {
                   s3 = (str1 + " " + s2 + " " + s1).trim();
                }

                if(!ntemp.contains(s3)){
                    ntemp.add(s3);
                }

            }
        }
    }
    temp.clear();

    for (int j = 0; j < ntemp.size(); j++) {

        int c = 0;
        for (int i = 0; i < trans.size(); i++) {

            int check = 0;
            String tr = trans.get(i);

            StringTokenizer st1 = new StringTokenizer(ntemp.get(j)," ");

            while(st1.hasMoreElements()){
                String str = st1.nextToken();
                if(!tr.contains(" " + str + " ")){
                    check = 1;
                    break;
                }
            }

            if(check == 0){
                c= 1;
                if (map.containsKey(ntemp.get(j))) {
                    int count = map.get(ntemp.get(j));
                    map.put(ntemp.get(j), count + 1);
                } else {
                    map.put(ntemp.get(j), 1);
                }
            }

        }
        if (c == 0) {
            ntemp.remove(j);
            j--;
        }
    }

    apriori(ntemp, m + 1);

}
}

快速代码:

import java.io.*;
import java.util.*;

public class Apriori3{

public static void main(String[] args) throws Exception {
    Apriori3 ap = new Apriori3(args);
}

private List<int[]> itemsets;
private String transaFile;
private int numItems;
private int numTransactions;
private double minSup;

private boolean usedAsLibrary = false;

public Apriori3(String[] args) throws Exception {
    configure(args);
    go();
}

private void go() throws Exception {
    long start = System.currentTimeMillis();

    createItemsetsOfSize1();
    int itemsetNumber = 1;
    int nbFrequentSets = 0;

    while (itemsets.size() > 0) {
        calculateFrequentItemsets();

        if (itemsets.size() != 0) {
            nbFrequentSets += itemsets.size();
            log("Found " + itemsets.size() + " frequent itemsets of size " + itemsetNumber + " (with support " + (minSup * 100) + "%)");;
            createNewItemsetsFromPreviousOnes();
        }

        itemsetNumber++;
    }

    long end = System.currentTimeMillis();
    log("Execution time is: " + ((double) (end - start) / 1000) + " seconds.");
    log("Found " + nbFrequentSets + " frequents sets for support " + (minSup * 100) + "% (absolute " + Math.round(numTransactions * minSup) + ")");
    log("Done");
}

private void foundFrequentItemSet(int[] itemset, int support) {
    if (usedAsLibrary) {
    } else {
        System.out.println(Arrays.toString(itemset) + "  (" + ((support / (double) numTransactions)) + " " + support + ")");
    }
}

private void log(String message) {
    if (!usedAsLibrary) {
        System.err.println(message);
    }
}

private void configure(String[] args) throws Exception {
    if (args.length != 0) {
        transaFile = args[0];
    } else {
        transaFile = "chess.dat"; // default
    }
    if (args.length >= 2) {
        minSup = (Double.valueOf(args[1]).doubleValue());
    } else {
        minSup = .8;
    }
    if (minSup > 1 || minSup < 0) {
        throw new Exception("minSup: bad value");
    }

    numItems = 0;
    numTransactions = 0;
    BufferedReader data_in = new BufferedReader(new FileReader(transaFile));
    while (data_in.ready()) {
        String line = data_in.readLine();
        if (line.matches("\\s*")) {
            continue;
        }
        numTransactions++;
        StringTokenizer t = new StringTokenizer(line, " ");
        while (t.hasMoreTokens()) {
            int x = Integer.parseInt(t.nextToken());
            if (x + 1 > numItems) {
                numItems = x + 1;
            }
        }
    }

    outputConfig();

}

private void outputConfig() {
    log("Input configuration: " + numItems + " items, " + numTransactions + " transactions, ");
    log("minsup = " + minSup + "%");
}

private void createItemsetsOfSize1() {
    itemsets = new ArrayList<int[]>();
    for (int i = 0; i < numItems; i++) {
        int[] cand = {i};
        itemsets.add(cand);
    }
}

private void createNewItemsetsFromPreviousOnes() {
    int currentSizeOfItemsets = itemsets.get(0).length;
    log("Creating itemsets of size " + (currentSizeOfItemsets + 1) + " based on " + itemsets.size() + " itemsets of size " + currentSizeOfItemsets);

    HashMap<String, int[]> tempCandidates = new HashMap<String, int[]>(); //temporary candidates

    for (int i = 0; i < itemsets.size(); i++) {
        for (int j = i + 1; j < itemsets.size(); j++) {
            int[] X = itemsets.get(i);
            int[] Y = itemsets.get(j);

            assert (X.length == Y.length);

            int[] newCand = new int[currentSizeOfItemsets + 1];
            for (int s = 0; s < newCand.length - 1; s++) {
                newCand[s] = X[s];
            }

            int ndifferent = 0;
            for (int s1 = 0; s1 < Y.length; s1++) {
                boolean found = false;
                for (int s2 = 0; s2 < X.length; s2++) {
                    if (X[s2] == Y[s1]) {
                        found = true;
                        break;
                    }
                }
                if (!found) {
                    ndifferent++;
                    newCand[newCand.length - 1] = Y[s1];
                }
            }
            assert (ndifferent > 0);
            if (ndifferent == 1) {
                Arrays.sort(newCand);
                tempCandidates.put(Arrays.toString(newCand), newCand);
            }
        }
    }
    itemsets = new ArrayList<int[]>(tempCandidates.values());
    log("Created " + itemsets.size() + " unique itemsets of size " + (currentSizeOfItemsets + 1));
}

private void line2booleanArray(String line, boolean[] trans) {
    Arrays.fill(trans, false);
    StringTokenizer stFile = new StringTokenizer(line, " ");
    while (stFile.hasMoreTokens()) {
        int parsedVal = Integer.parseInt(stFile.nextToken());
        trans[parsedVal] = true;
    }
}

private void calculateFrequentItemsets() throws Exception {

    log("Passing through the data to compute the frequency of " + itemsets.size() + " itemsets of size " + itemsets.get(0).length);
    List<int[]> frequentCandidates = new ArrayList<int[]>();

    boolean match;
    int count[] = new int[itemsets.size()];

    BufferedReader data_in = new BufferedReader(new InputStreamReader(new FileInputStream(transaFile)));

    boolean[] trans = new boolean[numItems];

    for (int i = 0; i < numTransactions; i++) {

        String line = data_in.readLine();
        line2booleanArray(line, trans);

        for (int c = 0; c < itemsets.size(); c++) {
            match = true;

            int[] cand = itemsets.get(c);

            for (int xx : cand) {
                if (trans[xx] == false) {
                    match = false;
                    break;
                }
            }
            if (match) {
                count[c]++;
            }
        }
    }
    data_in.close();
    for (int i = 0; i < itemsets.size(); i++) {

        if ((count[i] / (double) (numTransactions)) >= minSup) {
            foundFrequentItemSet(itemsets.get(i), count[i]);
            frequentCandidates.add(itemsets.get(i));
        }
    }
    itemsets = frequentCandidates;
}
}

0 个答案:

没有答案