我正在尝试按类索引过滤我的CIFAR-100 ndarray,这是我的代码:
def get_cifar100(folder, class_idx):
train_fname = os.path.join(folder, 'train')
test_fname = os.path.join(folder, 'test')
data_dict = unpickle(train_fname)
train_data = data_dict['data']
train_fine_labels = data_dict['fine_labels']
train_coarse_labels = data_dict['coarse_labels']
# Filtering process
filt_tdata = numpy.empty((0))
for i, v in enumerate(train_coarse_labels):
if v == class_idx:
filt_tdata = numpy.append(filt_tdata, train_data[i])
data_dict = unpickle(test_fname)
test_data = data_dict['data']
test_fine_labels = data_dict['fine_labels']
test_coarse_labels = data_dict['coarse_labels']
bm = unpickle(os.path.join(folder, 'meta'))
clabel_names = bm['coarse_label_names']
flabel_names = bm['fine_label_names']
return data_dict, filt_tdata, numpy.array(train_coarse_labels), numpy.array(train_fine_labels), test_data, numpy.array(test_coarse_labels), numpy.array(test_fine_labels), clabel_names, flabel_names
datapath = "./data/cifar-100-python"
data_dict, tr_data100, tr_clabels100, tr_flabels100, te_data100, te_clabels100, te_flabels100, clabel_names100, flabel_names100 = get_cifar100(datapath, 4)
print(len(tr_data100))
我想基于class_idx = 4(train_coarse_labels)过滤train_data。原始数组的大小为50000,过滤时应为5000。但是,我获得了超过其原始大小(700万++)。我的功能出了什么问题?感谢。