DL4J Canova example doesn't work

Date: 2016-01-28 13:39:22

Tags: java deeplearning4j

The Deeplearning4j Canova example doesn't work: eval.stats() reports NaN as the accuracy.
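As background for the NaN: in Java's double arithmetic, an accuracy ratio computed over zero examples is 0.0 / 0, which is NaN rather than 0. The sketch below is a hypothetical, simplified accuracy computation (the class and method names are made up). It compares the argmax of each output row with the argmax of its one-hot label. It is not DL4J's Evaluation implementation and is not claimed to be the cause here; it only shows how an empty evaluation turns into NaN.

```java
// Hypothetical, simplified accuracy computation; NOT DL4J's Evaluation class.
public class AccuracySketch {

    // Index of the largest value in a row (the predicted or true class).
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of rows whose predicted class matches the one-hot label.
    // With zero rows this is 0.0 / 0 in double arithmetic, i.e. NaN.
    static double accuracy(double[][] labels, double[][] predictions) {
        int correct = 0;
        for (int r = 0; r < labels.length; r++) {
            if (argMax(labels[r]) == argMax(predictions[r])) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0}, {0, 1}};
        double[][] preds = {{0.9, 0.1}, {0.8, 0.2}};
        System.out.println(accuracy(labels, preds));                     // 0.5
        System.out.println(accuracy(new double[0][], new double[0][]));  // NaN
    }
}
```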

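import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.override.ConfOverride;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.LoggerFactory;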


public class ImageClassifierExample {

    public static void main(String[] args) throws IOException, InterruptedException {

        // Path to the labeled images
        String labeledPath = System.getProperty("user.home") + "/lfw";
        List<String> labels = new ArrayList<>();
        for(File f : new File(labeledPath).listFiles()) {
            labels.add(f.getName());
        }
        // Instantiating a RecordReader pointing to the data path with the specified
        // height and width for each image; "true" appends the label to each record.
        RecordReader recordReader = new ImageRecordReader(28, 28, true, labels);
        recordReader.initialize(new FileSplit(new File(labeledPath)));

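        // Canova to DL4J: 784 (= 28 * 28) is the index of the appended label column,
        // labels.size() is the number of classes
        DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 784, labels.size());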

        // Creating configuration for the neural net.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                .constrainGradientToUnitNorm(true)
                .weightInit(WeightInit.DISTRIBUTION)
                .dist(new NormalDistribution(1,1e-5))
                .iterations(100).learningRate(1e-3)
                .nIn(784).nOut(labels.size())
                .visibleUnit(org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN)
                .hiddenUnit(org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit.RECTIFIED)
                .layer(new org.deeplearning4j.nn.conf.layers.RBM())
                .list(4).hiddenLayerSizes(600, 250, 100).override(3, new ConfOverride() {
                    @Override
                    public void overrideLayer(int i, NeuralNetConfiguration.Builder builder) {
                        if (i == 3) {
                            builder.layer(new org.deeplearning4j.nn.conf.layers.OutputLayer());
                            builder.activationFunction("softmax");
                            builder.lossFunction(LossFunctions.LossFunction.MCXENT);

                        }
                    }
                }).build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.setListeners(Arrays.<IterationListener>asList(new ScoreIterationListener(10)));

        // Training
        while(iter.hasNext()){
            DataSet next = iter.next();
            network.fit(next);
        }

        // Testing: we're not splitting into separate train and test sets,
        // so the training data is reused as the test set.
        iter.reset();
        Evaluation eval = new Evaluation();
        while(iter.hasNext()){
            DataSet next = iter.next();
            INDArray predict2 = network.output(next.getFeatureMatrix());
            eval.eval(next.getLabels(), predict2);
        }

        System.out.println(eval.stats());
    }
}
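The iterator call new RecordReaderDataSetIterator(recordReader, 784, labels.size()) only makes sense if each image record holds 28 * 28 = 784 pixel values followed by the label (the true passed to ImageRecordReader appends it), so that 784 is the label column and labels.size() is the number of classes. That reading assumes the three-argument constructor is (reader, labelIndex, numPossibleLabels), as in Canova-era DL4J. The sketch below is a hypothetical plain-Java illustration of that assumed layout (the class and method names are made up), splitting one flattened record into a feature vector and a one-hot label; it is not Canova's code.

```java
import java.util.Arrays;

// Hypothetical illustration of a flattened image record: NOT Canova/DL4J code.
// Assumed layout: [pixel_0, ..., pixel_783, labelIndex] for a 28x28 image.
public class RecordLayoutSketch {

    static final int HEIGHT = 28;
    static final int WIDTH = 28;
    static final int LABEL_INDEX = HEIGHT * WIDTH; // 784

    // Copies the pixel values that precede the label column.
    static double[] features(double[] record) {
        return Arrays.copyOfRange(record, 0, LABEL_INDEX);
    }

    // Turns the class index stored at LABEL_INDEX into a one-hot vector.
    static double[] oneHotLabel(double[] record, int numClasses) {
        int cls = (int) record[LABEL_INDEX];
        if (cls < 0 || cls >= numClasses) {
            throw new IllegalArgumentException("label " + cls + " out of range");
        }
        double[] oneHot = new double[numClasses];
        oneHot[cls] = 1.0;
        return oneHot;
    }

    public static void main(String[] args) {
        double[] record = new double[LABEL_INDEX + 1];
        Arrays.fill(record, 0, LABEL_INDEX, 0.5); // dummy pixel values
        record[LABEL_INDEX] = 2;                  // third class
        System.out.println(features(record).length);                  // 784
        System.out.println(Arrays.toString(oneHotLabel(record, 4)));  // [0.0, 0.0, 1.0, 0.0]
    }
}
```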

1 answer:

Answer 0 (score: 2)

Your NN configuration looks like it is based on a very old version of DL4J. The latest versions are:

DL4J: 0.4-rc3.8
ND4J: 0.4-rc3.8
Canova: 0.0.0.14

Please try the latest versions.