Deeplearning4j Word2Vec: Unable to get linear index exception

Time: 2015-09-09 12:56:17

Tags: java nlp deep-learning word2vec

I am trying to train a Word2Vec model with deeplearning4j.

Everything works fine until a number of errors like the following occur:

o.d.p.Parallelization - Error occurred processing data
java.lang.IllegalArgumentException: Unable to get linear index >= 10
    at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:3287) ~[nd4j-api-0.0.3.5.5.5.jar:na]
    at org.deeplearning4j.models.word2vec.VocabWord.getGradient(VocabWord.java:137) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.iterateSample(InMemoryLookupTable.java:277) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.word2vec.Word2Vec.iterate(Word2Vec.java:343) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.word2vec.Word2Vec.skipGram(Word2Vec.java:331) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]

The training process still continues, but this error seems to affect the results.
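
To see whether the vectors are actually degraded, I spot-check nearest neighbours after training. A minimal sketch (the probe words "day" and "night" are arbitrary and assumed to occur in the training corpus):

    // Sanity check after vec.fit(): if the vectors were corrupted by the
    // error above, the neighbour lists tend to look random.
    // The probe words are arbitrary and assumed to be in the corpus.
    java.util.Collection<String> nearest = vec.wordsNearest("day", 10);
    System.out.println("Nearest to 'day': " + nearest);
    System.out.println("similarity(day, night) = " + vec.similarity("day", "night"));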

The error still occurs with both version 0.0.3.3.4.alpha2 and version 0.4-rc1.2.

Below is the full code, implemented following the tutorial:
// Imports assume the deeplearning4j-nlp package layout of the versions above.
import java.io.File;
import java.io.IOException;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

// trainSetFileName, outputFileName and wordVectors are fields of the enclosing class.
public void train() throws IOException {
    // One sentence per line, lower-cased before tokenization.
    SentenceIterator iter = new LineSentenceIterator(new File(trainSetFileName));
    iter.setPreProcessor(new SentencePreProcessor() {
        @Override
        public String preProcess(String sentence) {
            return sentence.toLowerCase();
        }
    });

    // Strips common English word endings from each token before lookup.
    final EndingPreProcessor preProcessor = new EndingPreProcessor();
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new TokenPreProcess() {
        @Override
        public String preProcess(String token) {
            token = token.toLowerCase();
            String base = preProcessor.preProcess(token);
            //base = base.replaceAll("\\d", "d");
            return base;
        }
    });

    int batchSize = 1000;
    int iterations = 30;
    int layerSize = 300;

    Word2Vec vec = new Word2Vec.Builder()
            .batchSize(batchSize) // # words per minibatch
            .sampling(1e-5) // subsampling threshold: randomly drops very frequent words
            .minWordFrequency(5) // ignore words that occur fewer than 5 times
            .useAdaGrad(true) // adaptive per-parameter learning rates
            .layerSize(layerSize) // word feature vector size
            .iterations(iterations) // # iterations to train
            .learningRate(0.025) // initial learning rate
            .minLearningRate(1e-2) // learning rate decays wrt # words; this is the floor
            .negativeSample(10) // negative sampling with 10 sampled words
            .iterate(iter) // sentence source defined above
            .tokenizerFactory(tokenizer)
            .saveVocab(true)
            .workers(3)
            .build();
    vec.fit();

    wordVectors = vec;

    WordVectorSerializer.writeWordVectors(vec, outputFileName);
}
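
After fit() completes I read the vectors back from the file written above. A minimal sketch of reloading and querying them (loadTxtVectors is the plain-text loader in these versions; the probe word is an arbitrary assumption):

    // Reload the vectors produced by writeWordVectors above and query them.
    // WordVectors is org.deeplearning4j.models.embeddings.wordvectors.WordVectors.
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File(outputFileName));
    System.out.println(vectors.wordsNearest("day", 10)); // arbitrary probe word

Since the stack trace originates in VocabWord.getGradient, one diagnostic I would also try is rebuilding the model with .useAdaGrad(false) to see whether the per-word gradient-history lookup is involved; that is only a guess, not a confirmed fix.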

0 Answers:

No answers yet.