我正在为数据挖掘编写apriori算法的代码我的代码需要长达60秒才能获得一个非常小的数据集,这个数据集只需2秒即可通过互联网获得的其他代码解决但我不知道我在哪里做错了,有人可以告诉我为什么其他代码快速超过我的。
我的代码:
import java.util.*;
import java.io.*;
public class Apriori_p {
double support;
ArrayList<String> trans;
Map<String, Integer> map;
long start;
void print(ArrayList<String> temp) {
for (int i = 0; i < temp.size(); i++) {
System.out.println(temp.get(i));
}
System.out.println("Count :" + temp.size());
}
void run() throws FileNotFoundException {
start = System.currentTimeMillis();
trans = new ArrayList<>();
ArrayList<String> temp = new ArrayList<>();
map = new HashMap<>();
Scanner sc = new Scanner(System.in);
System.out.println("Enter support %");
support = sc.nextDouble();
System.out.println("Enter file name");
String file = sc.next();
sc = new Scanner(new File(file));
int lines = 0;
while (sc.hasNextLine()) {
String s = sc.nextLine();
if (s.matches("\\s*")) {
continue;
}
lines++;
String[] spl = s.split("\\s+");
ArrayList<Integer> elem = new ArrayList<>();
for (int i = 0; i < spl.length; i++) {
String cand;
int n = Integer.parseInt(spl[i]);
cand = spl[i].trim();
if (!elem.contains(n)) {
elem.add(n);
}
if (map.containsKey(cand)) {
int count = map.get(cand);
map.put(cand, count + 1);
} else {
map.put(cand, 1);
}
}
Collections.sort(elem);
String con = " ";
for (int i = 0; i < elem.size(); i++) {
con = con + elem.get(i) + " ";
String s1 = String.valueOf(elem.get(i)).trim();
if(!temp.contains(s1))
temp.add(s1);
}
trans.add(con);
}
support = (support * lines) / 100;
System.out.println(System.currentTimeMillis() - start);
apriori(temp, 1);
}
public static void main(String[] args) throws FileNotFoundException {
new Apriori_p().run();
}
public void apriori(ArrayList<String> temp, int m) {
Set<String> diff = null;
if (m == 1) {
diff = new HashSet<>();
}
for (int i = 0; i < temp.size(); i++) {
if (map.get(temp.get(i)) < support) {
if (m == 1) {
diff.add(temp.get(i));
}
temp.remove(i);
i--;
}
}
for (int i = 0; i < trans.size() && m == 1; i++) {
for (String j : diff) {
String rep = " " + j + " ";
trans.get(i).replace(rep, " ");
}
}
if (temp.size() == 0) {
return;
}
System.out.println("Size " + m + " :");
print(temp);
ArrayList<String> ntemp = new ArrayList<>();
int n = temp.size();
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
StringTokenizer st1 = new StringTokenizer(temp.get(i), " ");
StringTokenizer st2 = new StringTokenizer(temp.get(j), " ");
String str1 = "", str2 = "";
for (int s = 0; s < m - 2; s++) {
str1 = str1 + " " + st1.nextToken();
str2 = str2 + " " + st2.nextToken();
}
if (str2.compareToIgnoreCase(str1) == 0) {
int s1 = Integer.parseInt(st1.nextToken()), s2 = Integer.parseInt(st2.nextToken());
String s3;
if (s1 <= s2) {
s3 = (str1 + " " + s1 + " " + s2).trim();
} else {
s3 = (str1 + " " + s2 + " " + s1).trim();
}
if(!ntemp.contains(s3)){
ntemp.add(s3);
}
}
}
}
temp.clear();
for (int j = 0; j < ntemp.size(); j++) {
int c = 0;
for (int i = 0; i < trans.size(); i++) {
int check = 0;
String tr = trans.get(i);
StringTokenizer st1 = new StringTokenizer(ntemp.get(j)," ");
while(st1.hasMoreElements()){
String str = st1.nextToken();
if(!tr.contains(" " + str + " ")){
check = 1;
break;
}
}
if(check == 0){
c= 1;
if (map.containsKey(ntemp.get(j))) {
int count = map.get(ntemp.get(j));
map.put(ntemp.get(j), count + 1);
} else {
map.put(ntemp.get(j), 1);
}
}
}
if (c == 0) {
ntemp.remove(j);
j--;
}
}
apriori(ntemp, m + 1);
}
}
快速代码:
import java.io.*;
import java.util.*;
public class Apriori3{
public static void main(String[] args) throws Exception {
Apriori3 ap = new Apriori3(args);
}
private List<int[]> itemsets;
private String transaFile;
private int numItems;
private int numTransactions;
private double minSup;
private boolean usedAsLibrary = false;
public Apriori3(String[] args) throws Exception {
configure(args);
go();
}
private void go() throws Exception {
long start = System.currentTimeMillis();
createItemsetsOfSize1();
int itemsetNumber = 1;
int nbFrequentSets = 0;
while (itemsets.size() > 0) {
calculateFrequentItemsets();
if (itemsets.size() != 0) {
nbFrequentSets += itemsets.size();
log("Found " + itemsets.size() + " frequent itemsets of size " + itemsetNumber + " (with support " + (minSup * 100) + "%)");;
createNewItemsetsFromPreviousOnes();
}
itemsetNumber++;
}
long end = System.currentTimeMillis();
log("Execution time is: " + ((double) (end - start) / 1000) + " seconds.");
log("Found " + nbFrequentSets + " frequents sets for support " + (minSup * 100) + "% (absolute " + Math.round(numTransactions * minSup) + ")");
log("Done");
}
private void foundFrequentItemSet(int[] itemset, int support) {
if (usedAsLibrary) {
} else {
System.out.println(Arrays.toString(itemset) + " (" + ((support / (double) numTransactions)) + " " + support + ")");
}
}
private void log(String message) {
if (!usedAsLibrary) {
System.err.println(message);
}
}
private void configure(String[] args) throws Exception {
if (args.length != 0) {
transaFile = args[0];
} else {
transaFile = "chess.dat"; // default
}
if (args.length >= 2) {
minSup = (Double.valueOf(args[1]).doubleValue());
} else {
minSup = .8;
}
if (minSup > 1 || minSup < 0) {
throw new Exception("minSup: bad value");
}
numItems = 0;
numTransactions = 0;
BufferedReader data_in = new BufferedReader(new FileReader(transaFile));
while (data_in.ready()) {
String line = data_in.readLine();
if (line.matches("\\s*")) {
continue;
}
numTransactions++;
StringTokenizer t = new StringTokenizer(line, " ");
while (t.hasMoreTokens()) {
int x = Integer.parseInt(t.nextToken());
if (x + 1 > numItems) {
numItems = x + 1;
}
}
}
outputConfig();
}
private void outputConfig() {
log("Input configuration: " + numItems + " items, " + numTransactions + " transactions, ");
log("minsup = " + minSup + "%");
}
private void createItemsetsOfSize1() {
itemsets = new ArrayList<int[]>();
for (int i = 0; i < numItems; i++) {
int[] cand = {i};
itemsets.add(cand);
}
}
private void createNewItemsetsFromPreviousOnes() {
int currentSizeOfItemsets = itemsets.get(0).length;
log("Creating itemsets of size " + (currentSizeOfItemsets + 1) + " based on " + itemsets.size() + " itemsets of size " + currentSizeOfItemsets);
HashMap<String, int[]> tempCandidates = new HashMap<String, int[]>(); //temporary candidates
for (int i = 0; i < itemsets.size(); i++) {
for (int j = i + 1; j < itemsets.size(); j++) {
int[] X = itemsets.get(i);
int[] Y = itemsets.get(j);
assert (X.length == Y.length);
int[] newCand = new int[currentSizeOfItemsets + 1];
for (int s = 0; s < newCand.length - 1; s++) {
newCand[s] = X[s];
}
int ndifferent = 0;
for (int s1 = 0; s1 < Y.length; s1++) {
boolean found = false;
for (int s2 = 0; s2 < X.length; s2++) {
if (X[s2] == Y[s1]) {
found = true;
break;
}
}
if (!found) {
ndifferent++;
newCand[newCand.length - 1] = Y[s1];
}
}
assert (ndifferent > 0);
if (ndifferent == 1) {
Arrays.sort(newCand);
tempCandidates.put(Arrays.toString(newCand), newCand);
}
}
}
itemsets = new ArrayList<int[]>(tempCandidates.values());
log("Created " + itemsets.size() + " unique itemsets of size " + (currentSizeOfItemsets + 1));
}
private void line2booleanArray(String line, boolean[] trans) {
Arrays.fill(trans, false);
StringTokenizer stFile = new StringTokenizer(line, " ");
while (stFile.hasMoreTokens()) {
int parsedVal = Integer.parseInt(stFile.nextToken());
trans[parsedVal] = true;
}
}
private void calculateFrequentItemsets() throws Exception {
log("Passing through the data to compute the frequency of " + itemsets.size() + " itemsets of size " + itemsets.get(0).length);
List<int[]> frequentCandidates = new ArrayList<int[]>();
boolean match;
int count[] = new int[itemsets.size()];
BufferedReader data_in = new BufferedReader(new InputStreamReader(new FileInputStream(transaFile)));
boolean[] trans = new boolean[numItems];
for (int i = 0; i < numTransactions; i++) {
String line = data_in.readLine();
line2booleanArray(line, trans);
for (int c = 0; c < itemsets.size(); c++) {
match = true;
int[] cand = itemsets.get(c);
for (int xx : cand) {
if (trans[xx] == false) {
match = false;
break;
}
}
if (match) {
count[c]++;
}
}
}
data_in.close();
for (int i = 0; i < itemsets.size(); i++) {
if ((count[i] / (double) (numTransactions)) >= minSup) {
foundFrequentItemSet(itemsets.get(i), count[i]);
frequentCandidates.add(itemsets.get(i));
}
}
itemsets = frequentCandidates;
}
}