所以我试图实现名为deeplearning4j的Java免费深度学习库来解决nlp中的分类任务。
public static void Learn(String labelledDataFileName) throws Exception {
ParagraphVectors paragraphVectors = new ParagraphVectors();
InMemoryLookupCache cache = new InMemoryLookupCache();
LabelleDataIterator iterator = new LabelleDataIterator(new File(labelledDataFileName));
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
paragraphVectors = new ParagraphVectors.Builder()
.minWordFrequency(1)
.iterations(3)
.learningRate(0.025)
.minLearningRate(0.001)
.layerSize(400)
.batchSize(1000)
.epochs(1)
.iterate(iterator)
.trainWordVectors(true)
.vocabCache(cache)
.tokenizerFactory(t)
.build();
paragraphVectors.fit();
WordVectorSerializer.writeFullModel(paragraphVectors, MODEL_FILE_NAME);
}
非常标准,与网上提供的示例差别不大。然后使用方法writeFullModel将训练后的训练模型保存到文本文件中。然后可以使用此方法加载
WordVectorSerializer.loadFullModel(MODEL_FILE_NAME);
问题是,当模型变大时,它似乎不起作用。对于大小为120Mb的模型文件,我不断得到这个
Exception in thread "main" java.lang.IllegalArgumentException: Illegal slice 7151
at org.nd4j.linalg.api.ndarray.BaseNDArray.slice(BaseNDArray.java:2852)
at org.nd4j.linalg.api.ndarray.BaseNDArray.tensorAlongDimension(BaseNDArray.java:753)
at org.nd4j.linalg.api.ndarray.BaseNDArray.vectorAlongDimension(BaseNDArray.java:830)
at org.nd4j.linalg.api.ndarray.BaseNDArray.getRow(BaseNDArray.java:3628)
at org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadFullModel(WordVectorSerializer.java:523)
尽管如此,它仍然适用于小型模型文件。 任何帮助将不胜感激,非常感谢你。
答案 0 :(得分:2)
抛出IllegalArgumentException时,必须检查Java堆栈跟踪中的调用堆栈,并找到产生错误参数的方法。
Here是BaseNDArray.java类并查看行或搜索您收到错误的方法! 看到这个方法:
public INDArray slice(int slice) {
int slices = slices();
if(slice >= slices)
throw new IllegalArgumentException("Illegal slice " + slice);
所以,对于120MB文件,切片> =切片! 看看这是否有帮助,并回复你是如何解决问题的!