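I want to train Mahout for classification. My text comes from a database, and I really don't want to write it out to files just for Mahout training. I went through the MIA (Mahout in Action) source code and adapted the code below for a very basic training task. The usual problem with the Mahout examples is that they either show how to run Mahout from the command prompt on the 20 Newsgroups data, or the code has lots of dependencies on Hadoop, ZooKeeper and so on. I would be very grateful if someone could look over my code, or point me to a very simple tutorial that shows how to train a model and then use it.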
As it stands, the code below never gets past if (best != null), because learningAlgorithm.getBest() always returns null!
Sorry for posting the whole code, but I didn't see any other option.
import java.util.Iterator;
import java.util.List;

import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

public class Classifier {
    private static final int FEATURES = 10000;
    private static final TextValueEncoder encoder = new TextValueEncoder("body");
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args) throws Exception {
        int leakType = 0;
        AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
        // maps each category name ("good", "bad") to an integer target value
        Dictionary newsGroups = new Dictionary();
        //ModelDissector md = new ModelDissector();

        // in-memory training data (in the real application this comes from a database)
        ListMultimap<String, String> noteBySection = LinkedListMultimap.create();
        noteBySection.put("good", "I love this product, the screen is a pleasure to work with and is a great choice for any business");
        noteBySection.put("good", "What a product!! Really amazing clarity and works pretty well");
        noteBySection.put("good", "This product has good battery life and is a little bit heavy but I like it");
        noteBySection.put("bad", "I am really bored with the same UI, this is their 5th version(or fourth or sixth, who knows) and it looks just like the first one");
        noteBySection.put("bad", "The phone is bulky and useless");
        noteBySection.put("bad", "I wish i had never bought this laptop. It died in the first year and now i am not able to return it");

        encoder.setProbes(2);
        double step = 0;
        int[] bumps = {1, 2, 5};
        double averageCorrect = 0;
        double averageLL = 0;
        int k = 0;
        //-------------------------------------
        //notes.keySet()
        for (String key : noteBySection.keySet()) {
            System.out.println(key);
            List<String> notes = noteBySection.get(key);
            for (Iterator<String> it = notes.iterator(); it.hasNext();) {
                String note = it.next();
                // encode the note and feed it to the learner with its category as target
                int actual = newsGroups.intern(key);
                Vector v = encodeFeatureVector(note);
                learningAlgorithm.train(actual, v);
                k++;
                int bump = bumps[(int) Math.floor(step) % bumps.length];
                int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
                // best model found so far; null until the learner has selected one
                State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
                double maxBeta;
                double nonZeros;
                double positive;
                double norm;
                double lambda = 0;
                double mu = 0;
                if (best != null) {
                    CrossFoldLearner state = best.getPayload().getLearner();
                    averageCorrect = state.percentCorrect();
                    averageLL = state.logLikelihood();
                    OnlineLogisticRegression model = state.getModels().get(0);
                    // finish off pending regularization
                    model.close();
                    Matrix beta = model.getBeta();
                    maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
                    nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                        @Override
                        public double apply(double v) {
                            return Math.abs(v) > 1.0e-6 ? 1 : 0;
                        }
                    });
                    positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                        @Override
                        public double apply(double v) {
                            return v > 0 ? 1 : 0;
                        }
                    });
                    norm = beta.aggregate(Functions.PLUS, Functions.ABS);
                    lambda = learningAlgorithm.getBest().getMappedParams()[0];
                    mu = learningAlgorithm.getBest().getMappedParams()[1];
                } else {
                    maxBeta = 0;
                    nonZeros = 0;
                    positive = 0;
                    norm = 0;
                }
                System.out.println(k % (bump * scale));
                // periodically print progress and save a snapshot of the best model
                if (k % (bump * scale) == 0) {
                    if (learningAlgorithm.getBest() != null) {
                        System.out.println("----------------------------");
                        ModelSerializer.writeBinary("c:/tmp/news-group-" + k + ".model",
                                learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
                    }
                    step += 0.25;
                    System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
                    System.out.printf("%d\t%.3f\t%.2f\t%s\n",
                            k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
                }
            }
        }
        learningAlgorithm.close();
    }

    private static Vector encodeFeatureVector(String text) {
        encoder.addText(text.toLowerCase());
        //System.out.println(encoder.asString(text));
        Vector v = new RandomAccessSparseVector(FEATURES);
        bias.addToVector((byte[]) null, 1, v);
        encoder.flush(1, v);
        return v;
    }
}
Answer 0 (score: 2)
You need to add the words to your feature vector properly. It looks like the following line:
bias.addToVector((byte[]) null, 1, v);
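does not do what you expect. It just adds a null byte array to the feature vector with a weight of 1.
You are calling a wrapper around the WordValueEncoder.addToVector(byte[] originalForm, double w, Vector data) method.
Make sure you loop over the words in the values of your notes map and add each of them to the feature vector accordingly.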
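The idea of looping over the words and adding each one to the feature vector can be sketched in plain Java with feature hashing. This is a hypothetical, simplified stand-in, not Mahout's encoder: the class name HashedTextEncoder is made up, it hashes with String.hashCode() and a single probe (Mahout's encoders use their own hashing and a configurable number of probes), and it reserves slot 0 for a constant intercept term, similar in spirit to the question's ConstantValueEncoder("Intercept").

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of hashed text encoding (not Mahout's implementation):
 * every word of a note is hashed into one of FEATURES slots and its count is
 * added there. Slot 0 holds a constant intercept feature.
 */
class HashedTextEncoder {
    static final int FEATURES = 10000;
    static final int INTERCEPT = 0; // slot reserved for the constant bias term

    static double[] encode(String note) {
        double[] v = new double[FEATURES];
        v[INTERCEPT] = 1.0; // constant intercept feature with weight 1
        // Loop over every word so that each one actually lands in the vector.
        for (String word : note.toLowerCase().split("\\W+")) {
            if (word.isEmpty()) {
                continue; // split() can yield an empty leading token
            }
            // floorMod keeps the index non-negative; +1 skips the intercept slot.
            int index = 1 + Math.floorMod(word.hashCode(), FEATURES - 1);
            v[index] += 1.0; // each occurrence adds weight 1
        }
        return v;
    }

    public static void main(String[] args) {
        double[] v = encode("The phone is bulky and useless");
        // 1 (intercept) + 6 words
        System.out.println(Arrays.stream(v).sum()); // prints 7.0
    }
}
```

Because each word contributes a weight of 1 to its slot, the entries always sum to 1 (the intercept) plus the number of words, which gives a quick sanity check that the words really made it into the vector.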
Answer 1 (score: 0)
I strongly suggest that you also forward your question to the very helpful people on the Mahout mailing list: https://mahout.apache.org/general/mailing-lists,-irc-and-archives.html
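Answer 2 (score: 0)
This happened to me earlier today. I see you have very few initial samples, because you are playing around with the code just like I was. My problem was that, since this is an adaptive algorithm, I needed to set the interval and the averaging window for "adapting" very low, otherwise it would never find a new best model: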
learningAlgorithm.setInterval(1);
learningAlgorithm.setAveragingWindow(1);
This forces the algorithm to "adapt" after every single vector it sees, which is crucial because your sample code has only 6 vectors.
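The reasoning in this answer can be illustrated with a toy model. This is a hypothetical simplification, not Mahout's AdaptiveLogisticRegression: the made-up class ToyAdaptiveLearner only selects a "best" model at the end of each evaluation interval, so when there are fewer training examples than the interval, getBest() keeps returning null. The interval of 1000 is an arbitrary value chosen for illustration, not a claim about Mahout's defaults.

```java
/**
 * Toy model (hypothetical, not Mahout's code): a "best" model is only
 * chosen once every `interval` training examples.
 */
class ToyAdaptiveLearner {
    private final int interval;
    private int seen = 0;
    private String best = null; // stays null until the first evaluation

    ToyAdaptiveLearner(int interval) {
        this.interval = interval;
    }

    void train() {
        seen++;
        if (seen % interval == 0) {
            // Stand-in for scoring the candidate models and keeping the winner.
            best = "model selected after " + seen + " examples";
        }
    }

    String getBest() {
        return best;
    }

    // Trains a fresh learner on `examples` items and returns its best model.
    static String bestAfter(int interval, int examples) {
        ToyAdaptiveLearner learner = new ToyAdaptiveLearner(interval);
        for (int i = 0; i < examples; i++) {
            learner.train();
        }
        return learner.getBest();
    }

    public static void main(String[] args) {
        System.out.println(bestAfter(1000, 6)); // prints null: 6 notes never reach the interval
        System.out.println(bestAfter(1, 6));    // prints model selected after 6 examples
    }
}
```

With an interval larger than the data set, the selection step simply never runs, which matches the symptom in the question; shrinking the interval to 1 lets a best model appear after the very first example.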