I am building a classifier that has to read a large number of text documents, but I have noticed that my countWordFrequencies method gets slower the more documents it processes. The method below takes 60 ms (on my PC), while reading, normalizing, tokenizing, updating my vocabulary, and equalizing the various integer lists take only 3-5 ms in total (on my PC). My countWordFrequencies method looks like this:
public List<Integer> countWordFrequencies(String[] tokens)
{
    List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
    int counter = 0;
    for (int i = 0; i < vocabulary.size(); i++)
    {
        for (int j = 0; j < tokens.length; j++)
            if (tokens[j].equals(vocabulary.get(i)))
                counter++;
        wordFreqs.add(i, counter);
        counter = 0;
    }
    return wordFreqs;
}
What is the best way to speed this up? What is wrong with this approach?
Here is my whole class. There is also another class, Category; would it be useful to post that here as well, or do you not need it?
public class BayesianClassifier
{
    private Map<String,Integer> vocabularyWordFrequencies;
    private List<String> vocabulary;
    private List<Category> categories;
    private List<Integer> wordFrequencies;
    private int trainTextAmount;
    private int testTextAmount;
    private GUI gui;

    public BayesianClassifier()
    {
        this.vocabulary = new ArrayList<>();
        this.categories = new ArrayList<>();
        this.wordFrequencies = new ArrayList<>();
        this.trainTextAmount = 0;
        this.gui = new GUI(this);
        this.testTextAmount = 0;
    }

    public List<Category> getCategories()
    {
        return categories;
    }

    public List<String> getVocabulary()
    {
        return this.vocabulary;
    }

    public List<Integer> getWordFrequencies()
    {
        return wordFrequencies;
    }

    public int getTextAmount()
    {
        return testTextAmount + trainTextAmount;
    }

    public void updateWordFrequency(int index, Integer frequency)
    {
        equalizeIntList(wordFrequencies);
        this.wordFrequencies.set(index, wordFrequencies.get(index) + frequency);
    }

    public String readText(String path)
    {
        BufferedReader br;
        String result = "";
        try
        {
            br = new BufferedReader(new FileReader(path));
            StringBuilder sb = new StringBuilder();
            String line = br.readLine();
            while (line != null)
            {
                sb.append(line);
                sb.append("\n");
                line = br.readLine();
            }
            result = sb.toString();
            br.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        return result;
    }

    public String normalizeText(String text)
    {
        String fstNormalized = Normalizer.normalize(text, Normalizer.Form.NFD);
        fstNormalized = fstNormalized.replaceAll("[^\\p{ASCII}]","");
        fstNormalized = fstNormalized.toLowerCase();
        fstNormalized = fstNormalized.replace("\n","");
        fstNormalized = fstNormalized.replaceAll("[0-9]","");
        fstNormalized = fstNormalized.replaceAll("[/()!?;:,.%-]","");
        fstNormalized = fstNormalized.trim().replaceAll(" +", " ");
        return fstNormalized;
    }

    public String[] handleText(String path)
    {
        String text = readText(path);
        String normalizedText = normalizeText(text);
        return tokenizeText(normalizedText);
    }

    public void createCategory(String name, BayesianClassifier bc)
    {
        Category newCategory = new Category(name, bc);
        categories.add(newCategory);
    }

    public List<String> updateVocabulary(String[] tokens)
    {
        for (int i = 0; i < tokens.length; i++)
            if (!vocabulary.contains(tokens[i]))
                vocabulary.add(tokens[i]);
        return vocabulary;
    }

    public List<Integer> countWordFrequencies(String[] tokens)
    {
        List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
        int counter = 0;
        for (int i = 0; i < vocabulary.size(); i++)
        {
            for (int j = 0; j < tokens.length; j++)
                if (tokens[j].equals(vocabulary.get(i)))
                    counter++;
            wordFreqs.add(i, counter);
            counter = 0;
        }
        return wordFreqs;
    }

    public String[] tokenizeText(String normalizedText)
    {
        return normalizedText.split(" ");
    }

    public void handleTrainDirectory(String folderPath, Category category)
    {
        File folder = new File(folderPath);
        File[] listOfFiles = folder.listFiles();
        if (listOfFiles != null)
        {
            for (File file : listOfFiles)
            {
                if (file.isFile())
                {
                    handleTrainText(file.getPath(), category);
                }
            }
        }
        else
        {
            System.out.println("There are no files in the given folder" + " " + folderPath.toString());
        }
    }

    public void handleTrainText(String path, Category category)
    {
        long startTime = System.currentTimeMillis();
        trainTextAmount++;
        String[] text = handleText(path);
        updateVocabulary(text);
        equalizeAllLists();
        List<Integer> wordFrequencies = countWordFrequencies(text);
        long finishTime = System.currentTimeMillis();
        System.out.println("That took 1: " + (finishTime-startTime) + " ms");
        long startTime2 = System.currentTimeMillis();
        category.update(wordFrequencies);
        updatePriors();
        long finishTime2 = System.currentTimeMillis();
        System.out.println("That took 2: " + (finishTime2-startTime2) + " ms");
    }

    public void handleTestText(String path)
    {
        testTextAmount++;
        String[] text = handleText(path);
        List<Integer> wordFrequencies = countWordFrequencies(text);
        Category category = guessCategory(wordFrequencies);
        boolean correct = gui.askFeedback(path, category);
        if (correct)
        {
            category.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        }
        else
        {
            Category correctCategory = gui.askCategory();
            correctCategory.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        }
    }

    public void updatePriors()
    {
        for (Category category : categories)
        {
            category.updatePrior();
        }
    }

    public Category guessCategory(List<Integer> wordFrequencies)
    {
        List<Double> chances = new ArrayList<>();
        for (int i = 0; i < categories.size(); i++)
        {
            double chance = categories.get(i).getPrior();
            System.out.println("The prior is:" + chance);
            for (int j = 0; j < wordFrequencies.size(); j++)
            {
                chance = chance * categories.get(i).getWordProbabilities().get(j);
            }
            chances.add(chance);
        }
        double max = getMaxValue(chances);
        int index = chances.indexOf(max);
        System.out.println(max);
        System.out.println(index);
        return categories.get(index);
    }

    public double getMaxValue(List<Double> values)
    {
        Double max = 0.0;
        for (Double dubbel : values)
        {
            if (dubbel > max)
            {
                max = dubbel;
            }
        }
        return max;
    }

    public void equalizeAllLists()
    {
        for (Category category : categories)
        {
            if (category.getWordFrequencies().size() < vocabulary.size())
            {
                category.setWordFrequencies(equalizeIntList(category.getWordFrequencies()));
            }
        }
        for (Category category : categories)
        {
            if (category.getWordProbabilities().size() < vocabulary.size())
            {
                category.setWordProbabilities(equalizeDoubleList(category.getWordProbabilities()));
            }
        }
    }

    public List<Integer> equalizeIntList(List<Integer> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0);
        }
        return list;
    }

    public List<Double> equalizeDoubleList(List<Double> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0.0);
        }
        return list;
    }

    public void selectFeatures()
    {
        for (int i = 0; i < wordFrequencies.size(); i++)
        {
            if (wordFrequencies.get(i) < 2)
            {
                vocabulary.remove(i);
                wordFrequencies.remove(i);
                for (Category category : categories)
                {
                    category.removeFrequency(i);
                }
            }
        }
    }
}
Answer 0 (score: 11)
Your method runs in O(n*m) time (where n is the vocabulary size and m is the number of tokens). With hashing, this can be reduced to O(m), which is clearly better:
Map<String, Integer> map = new HashMap<>();
for (String token : tokens) {
    if (!map.containsKey(token)) {
        map.put(token, 0);
    }
    map.put(token, map.get(token) + 1);
}
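As a side note, the same counting loop can be written more compactly with Map.merge since Java 8 (a minimal sketch, assuming the same map variable as above):
for (String token : tokens) {
    map.merge(token, 1, Integer::sum); // inserts 1 for a new word, otherwise adds 1 to the existing count
}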
Answer 1 (score: 3)
Using a Map can improve performance considerably, as Sleiman Jneidi suggested in his answer. However, you can do even better with Java 8's stream API:
Map<String, Long> frequencies =
    Arrays.stream(tokens)
          .collect(Collectors.groupingBy(Function.identity(),
                                         Collectors.counting()));
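If the rest of your code still expects the frequencies as a List<Integer> aligned with your vocabulary list, you could map the stream result back (a sketch, assuming the vocabulary and frequencies variables above):
List<Integer> wordFreqs = vocabulary.stream()
    .map(word -> frequencies.getOrDefault(word, 0L).intValue()) // 0 for vocabulary words absent from this document
    .collect(Collectors.toList());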
Answer 2 (score: 2)
Instead of a vocabulary list plus a separate frequency list, I would use a Map that stores word -> frequency. That way you avoid the double loop, which in my opinion is what is killing your performance.
public Map<String,Integer> countWordFrequencies(String[] tokens) {
    // vocabulary is a Map<String,Integer> initialized with all words as keys and 0 as the value
    for (String word : tokens)
        if (vocabulary.containsKey(word)) {
            vocabulary.put(word, vocabulary.get(word) + 1);
        }
    return vocabulary;
}
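For completeness, a minimal sketch of how that vocabulary map could be seeded; initVocabulary is a hypothetical helper, not part of the original answer:
// Hypothetical helper: make sure every token is a key, starting at count 0.
private void initVocabulary(Map<String, Integer> vocabulary, String[] tokens) {
    for (String word : tokens) {
        vocabulary.putIfAbsent(word, 0); // leaves an existing count untouched
    }
}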
Answer 3 (score: 2)
If you don't want to use Java 8, you can try the Multiset from Guava.
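A minimal sketch of what that could look like, assuming Guava is on the classpath ("someWord" is just an illustrative placeholder):
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.util.Arrays;

// One pass over the tokens; the Multiset keeps a count per distinct element.
Multiset<String> counts = HashMultiset.create(Arrays.asList(tokens));
int frequency = counts.count("someWord"); // constant-time lookup of a word's frequency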