I am using the dl4j-examples from the Deeplearning4j project:
https://github.com/deeplearning4j/dl4j-examples
I am waiting for the run to finish so that I can use this module for other work, but the iterations never complete; training just keeps going.
I don't understand why training this NLP example takes so long, and I can't tell how much time it needs to finish. If anyone has any idea, please share it: roughly how long should this process take, and if possible, how can I make it faster? My system configuration:
RAM: 15.5 GiB
Processor: Intel® Core™ i3-4150 CPU @ 3.50GHz × 4
OS: 64-bit
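For scale, my own back-of-envelope estimate (not from the example): the aclImdb training set has 25,000 reviews, so with batchSize = 64 a single epoch is ceil(25000 / 64) = 391 minibatches, and since ScoreIterationListener(1) prints one score line per minibatch, the console output should show how far into the epoch the run is.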
What changes do I need to make to my code?
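One change that should at least make the progress measurable (my assumption; as far as I can tell DL4J ships a PerformanceListener, but it is not used anywhere in this example): log throughput per minibatch and extrapolate the total time from it.

// Hypothetical tweak (not in the original example): replace the
// net.setListeners(new ScoreIterationListener(1)) line in main() with
// listeners that also log timing. PerformanceListener(10, true) prints
// iteration time and samples/sec every 10 minibatches.
import org.deeplearning4j.optimize.listeners.PerformanceListener;

net.setListeners(new ScoreIterationListener(10), new PerformanceListener(10, true));

With a samples/sec figure from that listener, the total should be roughly 25,000 × nEpochs / throughput.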
Here is the code for Word2VecSentimentRNN:
package org.deeplearning4j.examples.recurrent.word2vecsentiment;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.*;
import java.net.URL;
public class Word2VecSentimentRNN {
    /** Data URL for downloading */
    public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
    /** Location to save and extract the training/testing data */
    public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
    /** Location (local file system) for the Google News vectors. Set this manually. */
    public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
    public static void main(String[] args) throws Exception {
        if (WORD_VECTORS_PATH.startsWith("/PATH/TO/YOUR/VECTORS/")) {
            throw new RuntimeException("Please set the WORD_VECTORS_PATH before running this example");
        }

        //Download and extract data
        downloadData();

        int batchSize = 64;     //Number of examples in each minibatch
        int vectorSize = 300;   //Size of the word vectors. 300 in the Google News model
        int nEpochs = 1;        //Number of epochs (full passes of training data) to train on
        int truncateReviewsToLength = 256;  //Truncate reviews with length (# words) greater than this
        //Set up network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.ADAM).adamMeanDecay(0.9).adamVarDecay(0.999)
            .regularization(true).l2(1e-5)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
            .learningRate(2e-2)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(256)
                .activation(Activation.TANH).build())
            .layer(1, new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(256).nOut(2).build())
            .pretrain(false).backprop(true).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        //DataSetIterators for training and testing respectively
        WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
        SentimentExampleIterator train = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true);
        SentimentExampleIterator test = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, false);

        System.out.println("Starting training");
        for (int i = 0; i < nEpochs; i++) {
            net.fit(train);
            train.reset();
            System.out.println("Epoch " + i + " complete. Starting evaluation:");

            //Run evaluation. This is on 25k reviews, so can take some time
            Evaluation evaluation = new Evaluation();
            while (test.hasNext()) {
                DataSet t = test.next();
                INDArray features = t.getFeatureMatrix();
                INDArray labels = t.getLabels();
                INDArray inMask = t.getFeaturesMaskArray();
                INDArray outMask = t.getLabelsMaskArray();
                INDArray predicted = net.output(features, false, inMask, outMask);

                evaluation.evalTimeSeries(labels, predicted, outMask);
            }
            test.reset();

            System.out.println(evaluation.stats());
        }
        //After training: load a single example and generate predictions
        File firstPositiveReviewFile = new File(FilenameUtils.concat(DATA_PATH, "aclImdb/test/pos/0_10.txt"));
        String firstPositiveReview = FileUtils.readFileToString(firstPositiveReviewFile);

        INDArray features = test.loadFeaturesFromString(firstPositiveReview, truncateReviewsToLength);
        INDArray networkOutput = net.output(features);
        int timeSeriesLength = networkOutput.size(2);
        INDArray probabilitiesAtLastWord = networkOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));

        System.out.println("\n\n-------------------------------");
        System.out.println("First positive review: \n" + firstPositiveReview);
        System.out.println("\n\nProbabilities at last time step:");
        System.out.println("p(positive): " + probabilitiesAtLastWord.getDouble(0));
        System.out.println("p(negative): " + probabilitiesAtLastWord.getDouble(1));

        System.out.println("----- Example complete -----");
    }
    private static void downloadData() throws Exception {
        //Create directory if required
        File directory = new File(DATA_PATH);
        if (!directory.exists()) directory.mkdir();

        //Download file:
        String archivePath = DATA_PATH + "aclImdb_v1.tar.gz";
        File archiveFile = new File(archivePath);
        String extractedPath = DATA_PATH + "aclImdb";
        File extractedFile = new File(extractedPath);

        if (!archiveFile.exists()) {
            System.out.println("Starting data download (80MB)...");
            FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
            System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
            //Extract tar.gz file to output directory
            extractTarGz(archivePath, DATA_PATH);
        } else {
            //Assume if archive (.tar.gz) exists, then data has already been extracted
            System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
            if (!extractedFile.exists()) {
                //Extract tar.gz file to output directory
                extractTarGz(archivePath, DATA_PATH);
            } else {
                System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
            }
        }
    }
    private static final int BUFFER_SIZE = 4096;

    private static void extractTarGz(String filePath, String outputPath) throws IOException {
        int fileCount = 0;
        int dirCount = 0;
        System.out.print("Extracting files");
        try (TarArchiveInputStream tais = new TarArchiveInputStream(
            new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(filePath))))) {
            TarArchiveEntry entry;

            //Read the tar entries using the getNextEntry method
            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                //Create directories as required
                if (entry.isDirectory()) {
                    new File(outputPath + entry.getName()).mkdirs();
                    dirCount++;
                } else {
                    //Copy the entry's contents to a file; try-with-resources closes the stream even on failure
                    byte[] data = new byte[BUFFER_SIZE];
                    try (BufferedOutputStream dest = new BufferedOutputStream(
                        new FileOutputStream(outputPath + entry.getName()), BUFFER_SIZE)) {
                        int count;
                        while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) {
                            dest.write(data, 0, count);
                        }
                    }
                    fileCount++;
                }
                if (fileCount % 1000 == 0) System.out.print(".");
            }
        }
        System.out.println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath);
    }
}
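In case it helps anyone answer: from reading the code, the settings that I believe dominate the per-iteration cost on a CPU are the review truncation length and the LSTM layer width. A quick-test variant I am considering (my own guesses, not values from the example, and they would change accuracy):

// Hypothetical quick-test settings (my guesses, not the example's values):
int truncateReviewsToLength = 64;  // 64 instead of 256: fewer LSTM time steps per review

// ...and a narrower recurrent layer, 64 units instead of 256 (the output
// layer's nIn must match the new width):
.layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(64)
    .activation(Activation.TANH).build())
.layer(1, new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
    .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(64).nOut(2).build())

Does that sound like the right direction, or is there something else in the configuration above that I should change?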