我是deeplearning4J的新手。我已经尝试了word2vec功能,一切都很好。但现在我对图像分类有点困惑。我正在玩这个例子:
我更改了"保存" flag为true,我的模型存储在model.bin文件中。 现在出现了问题部分(如果这听起来很愚蠢,我很抱歉,也许我错过了一些非常明显的东西)
我创建了一个名为AnimalClassifier的单独类,其目的是从model.bin文件加载模型,从中恢复神经网络,然后使用恢复的网络对单个图像进行分类。对于我创建的这个单一图像" temp"文件夹 - > dl4j-examples / src / main / resources / animals / temp /其中我把之前在动物训练过程中使用过的北极熊图片放在AnimalsClassification.java中(我想确保图像被正确分类 - 因此我重复使用了图片来自"熊"文件夹)。
这是我的代码试图对北极熊进行分类:
protected static int height = 100;
protected static int width = 100;
protected static int channels = 3;
protected static int numExamples = 1;
protected static int numLabels = 1;
protected static int batchSize = 10;
protected static long seed = 42;
protected static Random rng = new Random(seed);
protected static int listenerFreq = 1;
protected static int iterations = 1;
protected static int epochs = 7;
protected static double splitTrainTest = 0.8;
protected static int nCores = 2;
protected static boolean save = true;
protected static String modelType = "AlexNet"; //
public static void main(String[] args) throws Exception {
String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/");
MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/");
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
InputSplit analysedData = inputSplit[0];
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels);
recordReader.initialize(analysedData);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4);
while (dataIter.hasNext()) {
DataSet testDataSet = dataIter.next();
String expectedResult = testDataSet.getLabelName(0);
List<String> predict = multiLayerNetwork.predict(testDataSet);
String modelResult = predict.get(0);
System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n");
}
}
运行之后,我收到错误:
java.lang.UnsupportedOperationException at org.datavec.api.writable.ArrayWritable.toInt(ArrayWritable.java:47) at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:275) at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186) at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:389) at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52) at org.deeplearningarningj.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:66) 断开与目标VM的连接,地址:&#39; 127.0.0.1:63967&#39;,transport:&#39; socket&#39; 线程&#34; main&#34;中的例外情况java.lang.IllegalStateException:未在此数据集上定义标签名称。添加标签名称以使用带有id的getLabelName。 at org.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106) at org.deeplearningarningjj.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:68)
我可以看到MultiLayerNetwork.java中有一个方法public void setLabels(INDArray labels)但我不知道如何使用(特别是当它作为参数INDArray时)。
我也很困惑,为什么我必须在RecordReaderDataSetIterator的构造函数中指定可能的标签数量。我希望该模型已经知道要使用哪些标签(不应该使用自动训练期间使用的标签吗?)。我想,也许我正在以完全错误的方式加载图片......
总而言之,我想简单地完成以下任务:
提前感谢您的帮助或任何提示!
答案 0 :(得分:0)
在此总结您的多个问题: 图像记录是集合中的2个条目。第二个是标签。标签索引是相对于您传入的记录的种类。
问题的第二部分: 多个条目可以是数据集的一部分。该列表引用了小批量中特定行的项目标签。