使用dl4j和FCN网络进行语义分割

时间:2018-07-14 15:42:13

标签: java bigdata artificial-intelligence dl4j

我已经检查了dl4j示例并执行了 AnimalsClassification 示例,以成功进行测试。

  

我必须训练,评估和预测(使用)像 UNet 这样的语义分割算法,因为输入图像的大小不等于 FCN

     
    

并将 AnimalsClassification 示例的网络从该链接更改为 UNet

         
      

https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/UNet.java

    
  

但出现错误。 您能帮我解决这个错误吗?

错误:

SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
19:22:39,951 INFO  ~ Load data....
19:22:40,696 INFO  ~ Build model....
19:22:41,363 INFO  ~ Loaded [CpuBackend] backend
19:22:44,955 INFO  ~ Number of threads used for NativeOps: 2
19:22:45,555 INFO  ~ Number of threads used for BLAS: 2
19:22:45,562 INFO  ~ Backend used: [CPU]; OS: [Linux]
19:22:45,562 INFO  ~ Cores: [2]; Memory: [1.3GB];
19:22:45,562 INFO  ~ Blas vendor: [OPENBLAS]
19:22:56,425 WARN  ~ Layer "Layer not named" distribution is set but will not be applied unless weight init is set to WeighInit.DISTRIBUTION.
Exception in thread "main" java.lang.IllegalStateException: Invalid configuration: network has no inputs. Use .addInputs(String...) to label    (and give an ordering to) the network inputs
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.validate(ComputationGraphConfiguration.java:279)
    at org.deeplearning4j.nn.conf.ComputationGraphConfiguration$GraphBuilder.build(ComputationGraphConfiguration.java:918)
    at org.deeplearning4j.examples.convolution.AnimalsClassification.graphBuilder(AnimalsClassification.java:443)
    at org.deeplearning4j.examples.convolution.AnimalsClassification.run(AnimalsClassification.java:145)
    at org.deeplearning4j.examples.convolution.AnimalsClassification.main(AnimalsClassification.java:447)

Process finished with exit code 1

我更改的代码是:

package org.deeplearning4j.examples.convolution;

import lombok.Builder;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.StepSchedule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.List;
import java.util.Random;

import static java.lang.Math.toIntExact;

/**
 * Animal Classification
 *
 * Example classification of photos from 4 different animals (bear, duck, deer, turtle).
 *
 * References:
 *  - U.S. Fish and Wildlife Service (animal sample dataset): http://digitalmedia.fws.gov/cdm/
 *  - Tiny ImageNet Classification with CNN: http://cs231n.stanford.edu/reports/2015/pdfs/leonyao_final.pdf
 *
 * CHALLENGE: Current setup gets low score results. Can you improve the scores? Some approaches:
 *  - Add additional images to the dataset
 *  - Apply more transforms to dataset
 *  - Increase epochs
 *  - Try different model configurations
 *  - Tune by adjusting learning rate, updaters, activation & loss functions, regularization, ...
 */

public class AnimalsClassification {
    protected static final Logger log = LoggerFactory.getLogger(AnimalsClassification.class);
    protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;
    protected static int batchSize = 20;

//    protected static long seed = 42;
    private static long seed = 1234;
    protected static Random rng = new Random(seed);
    protected static int epochs = 50;
    protected static double splitTrainTest = 0.8;
    protected static boolean save = false;

    protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out
    private int numLabels;

    public void run(String[] args) throws Exception {

        log.info("Load data....");
        /**cd
         * Data Setup -> organize and limit data file paths:
         *  - mainPath = path to image files
         *  - fileSplit = define basic dataset split with limits on format
         *  - pathFilter = define additional file load filter to limit size and balance batch content
         **/
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/");
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        int numExamples = toIntExact(fileSplit.length());
        numLabels = fileSplit.getRootDir().listFiles(File::isDirectory).length; //This only works if your root is clean: only label subdirs.
        BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);

        /**
         * Data Setup -> train test split
         *  - inputSplit = define train and test split
         **/
        InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
        InputSplit trainData = inputSplit[0];
        InputSplit testData = inputSplit[1];

        /**
         * Data Setup -> transformation
         *  - Transform = how to tranform images and generate large dataset to train on
         **/
//        ImageTransform flipTransform1 = new FlipImageTransform(rng);
//        ImageTransform flipTransform2 = new FlipImageTransform(new Random(123));
//        ImageTransform warpTransform = new WarpImageTransform(rng, 42);
//        ImageTransform colorTransform = new ColorConversionTransform(new Random(seed), COLOR_BGR2YCrCb);
//        List<ImageTransform> transforms = Arrays.asList(new ImageTransform[]{flipTransform1, warpTransform, flipTransform2});

        /**
         * Data Setup -> normalization
         *  - how to normalize images and generate large dataset to train on
         **/
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);

        log.info("Build model....");

        // Uncomment below to try AlexNet. Note change height and width to at least 100
//        MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();

//        MultiLayerNetwork network;
//        switch (modelType) {
//            case "LeNet":
//                network = lenetModel();
//                break;
//            case "AlexNet":
//                network = alexnetModel();
//                break;
//            case "custom":
//                network = customModel();
//                break;
//            default:
//                throw new InvalidInputTypeException("Incorrect model provided.");
//        }
        ComputationGraph network = new ComputationGraph(graphBuilder());
        network.init();
       // network.setListeners(new ScoreIterationListener(listenerFreq));
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        network.setListeners(new StatsListener( statsStorage),new ScoreIterationListener(1));
        /**
         * Data Setup -> define how to load data into net:
         *  - recordReader = the reader that loads and converts image data pass in inputSplit to initialize
         *  - dataIter = a generator that only loads one batch at a time into memory to save memory
         *  - trainIter = uses MultipleEpochsIterator to ensure model runs through the data for all epochs
         **/
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        DataSetIterator dataIter;
        MultipleEpochsIterator trainIter;


        log.info("Train model....");
        // Train without transformations
        recordReader.initialize(trainData, null);
        dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);
        trainIter = new MultipleEpochsIterator(epochs, dataIter);
        network.fit(trainIter);

        // Train with transformations
/*        for (ImageTransform transform : transforms) {
            System.out.print("\nTraining on transformation: " + transform.getClass().toString() + "\n\n");
            recordReader.initialize(trainData, transform);
            dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
            scaler.fit(dataIter);
            dataIter.setPreProcessor(scaler);
            trainIter = new MultipleEpochsIterator(epochs, dataIter);
            network.fit(trainIter);
        }*/

        log.info("Evaluate model....");
        recordReader.initialize(testData);
        dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);
        Evaluation eval = network.evaluate(dataIter);
        log.info(eval.stats(true));

        // Example on how to get predict results with trained model. Result for first example in minibatch is printed
        dataIter.reset();
        DataSet testDataSet = dataIter.next();
        List<String> allClassLabels = recordReader.getLabels();
        int labelIndex = testDataSet.getLabels().argMax(1).getInt(0);
//        int[] predictedClasses = network.predict(testDataSet.getFeatures());
        String expectedResult = allClassLabels.get(labelIndex);
//        String modelPrediction = allClassLabels.get(predictedClasses[0]);
//        System.out.print("\nFor a single example that is labeled " + expectedResult + " the model predicted " + modelPrediction + "\n\n");

        if (save) {
            log.info("Save model....");
            String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "src/main/resources/");
            ModelSerializer.writeModel(network, basePath + "model.bin", true);
        }
        log.info("****************Example finished********************");
    }

    private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
    }

    private ConvolutionLayer conv3x3(String name, int out, double bias) {
        return new ConvolutionLayer.Builder(new int[]{3,3}, new int[] {1,1}, new int[] {1,1}).name(name).nOut(out).biasInit(bias).build();
    }

    private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(new int[]{5,5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
    }

    private SubsamplingLayer maxPool(String name,  int[] kernel) {
        return new SubsamplingLayer.Builder(kernel, new int[]{2,2}).name(name).build();
    }

    private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
        return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
    }

    public MultiLayerNetwork lenetModel() {
        /**
         * Revisde Lenet Model approach developed by ramgo2 achieves slightly above random
         * Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
         **/
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .l2(0.005)
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(0.0001,0.9))
            .list()
            .layer(0, convInit("cnn1", channels, 50 ,  new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
            .layer(1, maxPool("maxpool1", new int[]{2,2}))
            .layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0))
            .layer(3, maxPool("maxool2", new int[]{2,2}))
            .layer(4, new DenseLayer.Builder().nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true).pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();

        return new MultiLayerNetwork(conf);

    }

    public MultiLayerNetwork alexnetModel() {
        /**
         * AlexNet model interpretation based on the original paper ImageNet Classification with Deep Convolutional Neural Networks
         * and the imagenetExample code referenced.
         * http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
         **/

        double nonZeroBias = 1;
        double dropOut = 0.5;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .weightInit(WeightInit.DISTRIBUTION)
            .dist(new NormalDistribution(0.0, 0.01))
            .activation(Activation.RELU)
            .updater(new Nesterovs(new StepSchedule(ScheduleType.ITERATION, 1e-2, 0.1, 100000), 0.9))
            .biasUpdater(new Nesterovs(new StepSchedule(ScheduleType.ITERATION, 2e-2, 0.1, 100000), 0.9))
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
            .l2(5 * 1e-4)
            .list()
            .layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
            .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
            .layer(2, maxPool("maxpool1", new int[]{3,3}))
            .layer(3, conv5x5("cnn2", 256, new int[] {1,1}, new int[] {2,2}, nonZeroBias))
            .layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
            .layer(5, maxPool("maxpool2", new int[]{3,3}))
            .layer(6,conv3x3("cnn3", 384, 0))
            .layer(7,conv3x3("cnn4", 384, nonZeroBias))
            .layer(8,conv3x3("cnn5", 256, nonZeroBias))
            .layer(9, maxPool("maxpool3", new int[]{3,3}))
            .layer(10, fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(11, fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .name("output")
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();

        return new MultiLayerNetwork(conf);

    }

    public static MultiLayerNetwork customModel() {
        /**
         * Use this method to build your own custom model.
         **/
        return null;
    }

    @Builder.Default private int[] inputShape = new int[] {3, 512, 512};
    @Builder.Default private int numClasses = 0;
    @Builder.Default private WeightInit weightInit = WeightInit.RELU;
    @Builder.Default private IUpdater updater = new AdaDelta();
    @Builder.Default private CacheMode cacheMode = CacheMode.NONE;
//    @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
    @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.SINGLE;
    @Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;

    public ComputationGraphConfiguration graphBuilder() {

        ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(updater)
            .weightInit(weightInit)
            .dist(new TruncatedNormalDistribution(0.0, 0.5))
            .l2(5e-5)
            .miniBatch(true)
            .cacheMode(cacheMode)
            .trainingWorkspaceMode(workspaceMode)
            .inferenceWorkspaceMode(workspaceMode)
            .graphBuilder();


        graph
            .addLayer("conv1-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "input")
            .addLayer("conv1-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv1-1")
            .addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
                .build(), "conv1-2")

            .addLayer("conv2-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "pool1")
            .addLayer("conv2-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv2-1")
            .addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
                .build(), "conv2-2")

            .addLayer("conv3-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "pool2")
            .addLayer("conv3-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv3-1")
            .addLayer("pool3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
                .build(), "conv3-2")

            .addLayer("conv4-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "pool3")
            .addLayer("conv4-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv4-1")
            .addLayer("drop4", new DropoutLayer.Builder(0.5).build(), "conv4-2")
            .addLayer("pool4", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
                .build(), "drop4")

            .addLayer("conv5-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "pool4")
            .addLayer("conv5-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv5-1")
            .addLayer("drop5", new DropoutLayer.Builder(0.5).build(), "conv5-2")

            // up6
            .addLayer("up6-1", new Upsampling2D.Builder(2).build(), "drop5")
            .addLayer("up6-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(512)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "up6-1")
            .addVertex("merge6", new MergeVertex(), "drop4", "up6-2")
            .addLayer("conv6-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "merge6")
            .addLayer("conv6-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv6-1")

            // up7
            .addLayer("up7-1", new Upsampling2D.Builder(2).build(), "conv6-2")
            .addLayer("up7-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(256)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "up7-1")
            .addVertex("merge7", new MergeVertex(), "conv3-2", "up7-2")
            .addLayer("conv7-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "merge7")
            .addLayer("conv7-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv7-1")

            // up8
            .addLayer("up8-1", new Upsampling2D.Builder(2).build(), "conv7-2")
            .addLayer("up8-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(128)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "up8-1")
            .addVertex("merge8", new MergeVertex(), "conv2-2", "up8-2")
            .addLayer("conv8-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "merge8")
            .addLayer("conv8-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv8-1")

            // up9
            .addLayer("up9-1", new Upsampling2D.Builder(2).build(), "conv8-2")
            .addLayer("up9-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(64)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "up8-1")
            .addVertex("merge9", new MergeVertex(), "conv1-2", "up9-2")
            .addLayer("conv9-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "merge9")
            .addLayer("conv9-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv9-1")
            .addLayer("conv9-3", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(2)
                .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.RELU).build(), "conv9-2")

            .addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1)
                .convolutionMode(ConvolutionMode.Truncate).cudnnAlgoMode(cudnnAlgoMode)
                .activation(Activation.SIGMOID).build(), "conv9-3")
            .addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.MCXENT).build(), "conv10")

            .setOutputs("output").backprop(true).pretrain(false);

        return graph.build();
    }

    public static void main(String[] args) throws Exception {
        new AnimalsClassification().run(args);
    }

}

非常感谢。

1 个答案:

答案 0 :(得分:0)

它指出了堆栈跟踪中的错误。您需要在图形上调用setInputs。这类似于我们在底部设置输出的方式。 请仔细阅读错误消息。