我想打印用于分类的traindata / testdata标签。这是两个输入的定义(使用deep4j)。
InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
InputSplit trainData = inputSplit[0];
InputSplit testData = inputSplit[1];
然后在DataSetIterator中进行转换,如下所示:
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
recordReader.initialize(trainData, null);
trainIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
然后我想打印在此函数中每个迭代器中找到的每个标签的示例数:
public void print(DataSetIterator iter){
HashMap<String, Integer> hash = new HashMap<String, Integer>();
while(iter.hasNext()){
DataSet example = iter.next();
for(int i = 0 ; i<numLabels ; i++){
if(example.getLabels().getDouble(i)==1.){
String label = example.getLabelName(i);
if(hash.containsKey(label))
hash.put(label, hash.get(label)+1);
else
hash.put(label, 1);
}
}
}
for (String label: hash.keySet()){
System.out.println(" label : " + label.toString() + ", " + hash.get(label) + " examples");
}
}
问题在于它每个标签只显示一个示例,而应该还有更多......当我不使用fileSplit.sample()
拆分我的数据集时,该函数会显示正确数量的示例。
有什么建议吗?
答案 0 :(得分:0)
如果您使用数据集,则可以使用dataset.getFeatureMatrix()和dataset.getLabels()的toString()
如果你只想打印标签计数,你可以使用dataset.labelCounts()我会更多地看看dl4j javadoc: http://deeplearning4j.org/doc