显示数据标签[deep4j]

时间:2016-11-11 12:27:51

标签: java deep-learning deeplearning4j

我想打印用于分类的traindata / testdata标签。这是两个输入的定义(使用deep4j)。

    InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
    InputSplit trainData = inputSplit[0];
    InputSplit testData = inputSplit[1];

然后在DataSetIterator中进行转换,如下所示:

    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
    recordReader.initialize(trainData, null);
    trainIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);

然后我想打印在此函数中每个迭代器中找到的每个标签的示例数:

public void print(DataSetIterator iter){

    HashMap<String, Integer> hash = new HashMap<String, Integer>();

    while(iter.hasNext()){
        DataSet example = iter.next();
        for(int i = 0 ; i<numLabels ; i++){
            if(example.getLabels().getDouble(i)==1.){
                String label = example.getLabelName(i);
                if(hash.containsKey(label))
                    hash.put(label, hash.get(label)+1);
                else
                    hash.put(label, 1);
            }
        }
    }

    for (String label: hash.keySet()){
        System.out.println("   label : " + label.toString() + ", " + hash.get(label) + " examples");
    }
}

问题在于它每个标签只显示一个示例,而应该还有更多......当我不使用fileSplit.sample()拆分我的数据集时,该函数会显示正确数量的示例。 有什么建议吗?

1 个答案:

答案 0 :(得分:0)

如果您使用数据集,则可以使用dataset.getFeatureMatrix()和dataset.getLabels()的toString()

如果你只想打印标签计数,你可以使用dataset.labelCounts()我会更多地看看dl4j javadoc: http://deeplearning4j.org/doc