在deeplearning4j(卷积网络)中对经过训练的自定义模型分类新图像

时间:2017-04-08 10:29:30

标签: deeplearning4j

我是deeplearning4J的新手。我已经尝试了word2vec功能,一切都很好。但现在我对图像分类有点困惑。我正在玩这个例子:

https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/AnimalsClassification.java

我更改了"保存" flag为true,我的模型存储在model.bin文件中。 现在出现了问题部分(如果这听起来很愚蠢,我很抱歉,也许我错过了一些非常明显的东西)

我创建了一个名为AnimalClassifier的单独类,其目的是从model.bin文件加载模型,从中恢复神经网络,然后使用恢复的网络对单个图像进行分类。对于我创建的这个单一图像" temp"文件夹 - > dl4j-examples / src / main / resources / animals / temp /其中我把之前在动物训练过程中使用过的北极熊图片放在AnimalsClassification.java中(我想确保图像被正确分类 - 因此我重复使用了图片来自"熊"文件夹)。

这是我的代码试图对北极熊进行分类:

protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;
    protected static int numExamples = 1;
    protected static int numLabels = 1;
    protected static int batchSize = 10;

    protected static long seed = 42;
    protected static Random rng = new Random(seed);
    protected static int listenerFreq = 1;
    protected static int iterations = 1;
    protected static int epochs = 7;
    protected static double splitTrainTest = 0.8;
    protected static int nCores = 2;
    protected static boolean save = true;

    protected static String modelType = "AlexNet"; //

    public static void main(String[] args) throws Exception {

        String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/");
        MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true);

        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/");
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);


        InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
        InputSplit analysedData = inputSplit[0];


        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels);
        recordReader.initialize(analysedData);
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4);
        while (dataIter.hasNext()) {
            DataSet testDataSet = dataIter.next();

            String expectedResult = testDataSet.getLabelName(0);
            List<String> predict = multiLayerNetwork.predict(testDataSet);
            String modelResult = predict.get(0);
            System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n");
        }
    }

运行之后,我收到错误:

java.lang.UnsupportedOperationException     at org.datavec.api.writable.ArrayWritable.toInt(ArrayWritable.java:47)     at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:275)     at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186)     at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:389)     at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52)     at org.deeplearningarningj.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:66) 断开与目标VM的连接,地址:&#39; 127.0.0.1:63967&#39;,transport:&#39; socket&#39; 线程&#34; main&#34;中的例外情况java.lang.IllegalStateException:未在此数据集上定义标签名称。添加标签名称以使用带有id的getLabelName。     at org.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106)     at org.deeplearningarningjj.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:68)

我可以看到MultiLayerNetwork.java中有一个方法public void setLabels(INDArray labels)但我不知道如何使用(特别是当它作为参数INDArray时)。

我也很困惑,为什么我必须在RecordReaderDataSetIterator的构造函数中指定可能的标签数量。我希望该模型已经知道要使用哪些标签(不应该使用自动训练期间使用的标签吗?)。我想,也许我正在以完全错误的方式加载图片......

总而言之,我想简单地完成以下任务:

  1. 从模型中恢复网络(这是有效的)
  2. 加载要分类的图像(也正常工作)
  3. 使用培训期间使用的相同标签(熊,鹿,鸭,龟)对此图像进行分类(棘手的部分)
  4. 提前感谢您的帮助或任何提示!

1 个答案:

答案 0 :(得分:0)

在此总结您的多个问题: 图像记录是集合中的2个条目。第二个是标签。标签索引是相对于您传入的记录的种类

问题的第二部分: 多个条目可以是数据集的一部分。该列表引用了小批量中特定的项目标签。