I'm trying to train a Word2Vec model with deeplearning4j.
Everything works fine until a large number of errors like this start appearing:
o.d.p.Parallelization - Error occurred processing data
java.lang.IllegalArgumentException: Unable to get linear index >= 10
at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:3287) ~[nd4j-api-0.0.3.5.5.5.jar:na]
at org.deeplearning4j.models.word2vec.VocabWord.getGradient(VocabWord.java:137) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
at org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.iterateSample(InMemoryLookupTable.java:277) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
at org.deeplearning4j.models.word2vec.Word2Vec.iterate(Word2Vec.java:343) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
at org.deeplearning4j.models.word2vec.Word2Vec.skipGram(Word2Vec.java:331) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
Training keeps running, but the error seems to affect the results.
Versions: 0.0.3.3.4.alpha2 and 0.4-rc1.2.
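The failing frame is `VocabWord.getGradient`, called from `InMemoryLookupTable.iterateSample`. Since the builder enables `.useAdaGrad(true)`, that method appears to be the per-word AdaGrad step. The sketch below is a hypothetical, simplified accumulator (the class `AdaGradHistory` is invented for illustration and is not DL4J code). It shows how an AdaGrad history stored in a fixed-length array produces this kind of "index >= length" failure when a caller passes a slot past the end of the array. It does not claim to identify which index DL4J actually passes.

```java
/**
 * HYPOTHETICAL per-word AdaGrad accumulator, not DL4J's VocabWord.
 * It keeps one running sum of squared gradients per slot in a fixed-length array.
 */
public class AdaGradHistory {
    private final double[] history;     // sum of squared gradients, one entry per slot
    private final double learningRate;  // base learning rate
    private static final double EPS = 1e-6; // avoids division by zero

    public AdaGradHistory(int slots, double learningRate) {
        this.history = new double[slots];
        this.learningRate = learningRate;
    }

    /** Returns the AdaGrad-scaled step for slot {@code index} given raw gradient {@code g}. */
    public double step(int index, double g) {
        if (index < 0 || index >= history.length) {
            // Message only imitates the shape of the one in the stack trace.
            throw new IllegalArgumentException(
                    "Unable to get linear index >= " + history.length);
        }
        history[index] += g * g;                                   // accumulate g^2
        return learningRate * g / (Math.sqrt(history[index]) + EPS); // scaled step
    }

    public static void main(String[] args) {
        AdaGradHistory h = new AdaGradHistory(10, 0.025);
        System.out.println(h.step(3, 0.5)); // first step is close to 0.025
        System.out.println(h.step(3, 0.5)); // repeated slot: step shrinks
        try {
            h.step(12, 0.5); // slot outside the 10-element history
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Each repeated update to the same slot shrinks the step, which is the point of AdaGrad. The bounds check is what fails as soon as the history array is shorter than the range of indices the caller uses.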
Below is the complete code, implemented following the tutorial:
public void train() throws IOException {
    SentenceIterator iter = new LineSentenceIterator(new File(trainSetFileName));
    iter.setPreProcessor(new SentencePreProcessor() {
        @Override
        public String preProcess(String sentence) {
            return sentence.toLowerCase();
        }
    });
    final EndingPreProcessor preProcessor = new EndingPreProcessor();
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new TokenPreProcess() {
        @Override
        public String preProcess(String token) {
            token = token.toLowerCase();
            String base = preProcessor.preProcess(token);
            //base = base.replaceAll("\\d", "d");
            return base;
        }
    });
    int batchSize = 1000;
    int iterations = 30;
    int layerSize = 300;
    Word2Vec vec = new Word2Vec.Builder()
            .batchSize(batchSize)      // # words per minibatch
            .sampling(1e-5)            // subsampling threshold: randomly drops frequent words
            .minWordFrequency(5)       // ignore words seen fewer than 5 times
            .useAdaGrad(true)          // adaptive per-parameter learning rates (AdaGrad)
            .layerSize(layerSize)      // word feature vector size
            .iterations(iterations)    // # iterations to train
            .learningRate(0.025)       // initial learning rate
            .minLearningRate(1e-2)     // learning rate decays as words are processed; this is the floor
            .negativeSample(10)        // number of negative samples per training pair
            .iterate(iter)             // sentence source
            .tokenizerFactory(tokenizer)
            .saveVocab(true)
            .workers(3)
            .build();
    vec.fit();
    wordVectors = vec;
    WordVectorSerializer.writeWordVectors(vec, outputFileName);
}
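For reference, `.sampling(1e-5)` is the frequent-word subsampling threshold (negative sampling is configured separately by `.negativeSample(10)`), and `.minLearningRate(1e-2)` sets a floor for the decaying learning rate that starts at `.learningRate(0.025)`. The sketch below illustrates both under the assumption that the formulas of the reference word2vec.c implementation apply; DL4J's internal formulas may differ. The class `Word2VecSchedules` and its methods are hypothetical helpers, not part of DL4J.

```java
/**
 * HYPOTHETICAL helpers illustrating two Word2Vec hyperparameters.
 * Formulas follow the reference word2vec.c implementation (an assumption here);
 * DL4J's internals may differ.
 */
public class Word2VecSchedules {

    /**
     * Probability of KEEPING one occurrence of a word under frequent-word subsampling:
     * keep = (sqrt(count / (t * total)) + 1) * (t * total) / count, capped at 1.
     */
    public static double keepProbability(long wordCount, long totalWords, double sample) {
        double threshold = sample * totalWords;
        double keep = (Math.sqrt(wordCount / threshold) + 1.0) * threshold / wordCount;
        return Math.min(1.0, keep); // rare words are always kept
    }

    /** Linear decay from startLr toward 0 over totalWords, clamped at minLr. */
    public static double learningRate(double startLr, double minLr,
                                      long wordsSeen, long totalWords) {
        double lr = startLr * (1.0 - (double) wordsSeen / totalWords);
        return Math.max(minLr, lr);
    }

    public static void main(String[] args) {
        long total = 10_000_000L;
        // sampling(1e-5): a word making up 1% of the corpus vs. a rare word
        System.out.println(keepProbability(100_000L, total, 1e-5));
        System.out.println(keepProbability(50L, total, 1e-5));
        // learningRate(0.025) with minLearningRate(1e-2), halfway and 80% through
        System.out.println(learningRate(0.025, 1e-2, total / 2, total));
        System.out.println(learningRate(0.025, 1e-2, 8_000_000L, total));
    }
}
```

With these numbers, a word that makes up 1% of a 10-million-word corpus keeps only about 3% of its occurrences. The learning rate reaches the 0.01 floor once 60% of the words have been processed (0.025 × 0.4 = 0.01) and stays there for the rest of training.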