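I am using Mallet from Java, and I can't figure out how to evaluate new documents against an existing topic model that I have already trained.
The code that builds my model is very similar to the code in the Mallet Developer's Guide for Topic Modeling; afterwards I simply save the model as a Java object. Later in the process, I reload that Java object from file, add new instances via .addInstances(), and then want to evaluate only those new instances against the topics found in the original training set.
This stats.SE thread offers some high-level advice, but I can't see how to apply it within the Mallet framework.
Any help is much appreciated.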
Answer 0 (score: 3)
I found the answer tucked away in a slide deck from Mallet's lead developer:
TopicInferencer inferencer = model.getInferencer();
// arguments: instance, numIterations = 100, thinning = 10, burnIn = 10
double[] topicProbs = inferencer.getSampledDistribution(newInstance, 100, 10, 10);
Answer 1 (score: 3)
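Inference is actually also shown in the example link given in the question (in its last few lines).
For anyone interested in the complete code for saving and loading a trained model and then using it to infer the topic distribution of a new document, here are some snippets: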
Once model.estimate() has finished, you have the trained model, so you can serialize it with the standard Java ObjectOutputStream (ParallelTopicModel implements Serializable):
// try-with-resources closes the stream even if writeObject throws
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("model.ser"))) {
    oos.writeObject(model);
} catch (FileNotFoundException ex) {
    // handle this error
} catch (IOException ex) {
    // handle this error
}
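Note, however, that at inference time you also need to pass the new sentence (as an Instance) through the same pipeline so that it is preprocessed in the same way (tokenized, etc.). You therefore need to save the pipe list as well; this works because the pipes were wrapped in a SerialPipes object when the instances were created, and that object can be serialized: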
// wrap the pipe list (the same one used for model training)
SerialPipes pipes = new SerialPipes(pipeList);
// try-with-resources closes the stream even if writeObject throws
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream("pipes.ser"))) {
    oos.writeObject(pipes);
} catch (FileNotFoundException ex) {
    // handle error
} catch (IOException ex) {
    // handle error
}
To load the model and the pipes and use them for inference, we need to deserialize both:
private static void InferByModel(String sentence) {
    // define model and pipeline
    ParallelTopicModel model = null;
    SerialPipes pipes = null;

    // load the model (the stream is closed automatically)
    try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("model.ser"))) {
        model = (ParallelTopicModel) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read model from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the model: " + ex);
    }

    // load the pipeline
    try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("pipes.ser"))) {
        pipes = (SerialPipes) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read pipes from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the pipes: " + ex);
    }

    // if both are properly loaded
    if (model != null && pipes != null) {
        // Create a new instance named "test instance" with empty target
        // and source fields; note that we reuse the loaded pipes here
        InstanceList testing = new InstanceList(pipes);
        testing.addThruPipe(new Instance(sentence, null, "test instance", null));

        // get an inferencer from the loaded model and use it;
        // arguments: instance, numIterations = 10, thinning = 1, burnIn = 5
        TopicInferencer inferencer = model.getInferencer();
        double[] testProbabilities =
                inferencer.getSampledDistribution(testing.get(0), 10, 1, 5);
        // print the inferred probability of topic 0
        System.out.println("0\t" + testProbabilities[0]);
    }
}
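The save and load steps above repeat the same stream handling for the model and for the pipes. As a rough sketch, that boilerplate could be folded into one small generic helper. The class name SerializationUtil and its two methods are hypothetical (they are not part of Mallet); the helper uses only the standard library and works for any Serializable object.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical helper (not part of Mallet) for saving and loading Serializable objects. */
final class SerializationUtil {
    private SerializationUtil() {}

    /** Writes obj to path; try-with-resources closes the stream even if writing fails. */
    static void save(Serializable obj, Path path) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(path))) {
            out.writeObject(obj);
        }
    }

    /** Reads one object from path and checks that it has the expected type. */
    static <T> T load(Path path, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(path))) {
            // type.cast throws ClassCastException if the file holds a different class
            return type.cast(in.readObject());
        }
    }
}
```

With such a helper, the calls would presumably shrink to SerializationUtil.save(model, Paths.get("model.ser")) and SerializationUtil.load(Paths.get("model.ser"), ParallelTopicModel.class), and likewise for the SerialPipes object.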
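For some reason, I don't get exactly the same inference as with the original model, but that is a matter for a separate question (if anyone knows why, I'd be glad to hear).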
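One possible contributor to the mismatch, offered as an assumption rather than a diagnosis: getSampledDistribution estimates the topic distribution by Gibbs sampling, so its output depends on random draws and on how many iterations are run. The toy below is not Mallet's sampler; it only illustrates, with a made-up three-topic distribution, that a sampled estimate repeats exactly when the random seed is fixed and gets closer to the true distribution as the number of draws grows. If your Mallet version's TopicInferencer lets you set a random seed, fixing it (and using more iterations than 10) may be worth trying.

```java
import java.util.Arrays;
import java.util.Random;

/** Toy illustration of sampling noise; this is NOT Mallet's inference code. */
final class SampledEstimateDemo {

    /** Estimates a categorical distribution from `draws` random samples of `probs`. */
    static double[] estimate(double[] probs, int draws, Random random) {
        double[] counts = new double[probs.length];
        for (int i = 0; i < draws; i++) {
            double u = random.nextDouble();
            double cumulative = 0.0;
            int k = 0;
            // walk the cumulative distribution until it exceeds u
            while (k < probs.length - 1 && u >= cumulative + probs[k]) {
                cumulative += probs[k];
                k++;
            }
            counts[k]++;
        }
        // turn counts into relative frequencies
        for (int k = 0; k < counts.length; k++) {
            counts[k] /= draws;
        }
        return counts;
    }

    public static void main(String[] args) {
        double[] truth = {0.5, 0.3, 0.2}; // made-up distribution for the demo
        // same seed twice: the two estimates are identical
        System.out.println(Arrays.toString(estimate(truth, 10, new Random(1))));
        System.out.println(Arrays.toString(estimate(truth, 10, new Random(1))));
        // different seed, few draws: the estimate may differ noticeably
        System.out.println(Arrays.toString(estimate(truth, 10, new Random(2))));
        // many draws: the estimate settles near the true distribution
        System.out.println(Arrays.toString(estimate(truth, 100_000, new Random(3))));
    }
}
```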