Filtering a Python NumPy ndarray

Time: 2016-11-20 13:12:42

Tags: python numpy machine-learning deep-learning keras

I'm trying to filter my CIFAR-100 ndarray by class index. Here is my code:

import os
import numpy

def get_cifar100(folder, class_idx):
    train_fname = os.path.join(folder, 'train')
    test_fname = os.path.join(folder, 'test')
    data_dict = unpickle(train_fname)
    train_data = data_dict['data']
    train_fine_labels = data_dict['fine_labels']
    train_coarse_labels = data_dict['coarse_labels']

    # Filtering process
    filt_tdata = numpy.empty((0))
    for i, v in enumerate(train_coarse_labels):
        if v == class_idx:
            filt_tdata = numpy.append(filt_tdata, train_data[i])

    data_dict = unpickle(test_fname)
    test_data = data_dict['data']
    test_fine_labels = data_dict['fine_labels']
    test_coarse_labels = data_dict['coarse_labels']

    bm = unpickle(os.path.join(folder, 'meta'))
    clabel_names = bm['coarse_label_names']
    flabel_names = bm['fine_label_names']

    return data_dict, filt_tdata, numpy.array(train_coarse_labels), numpy.array(train_fine_labels), test_data, numpy.array(test_coarse_labels), numpy.array(test_fine_labels), clabel_names, flabel_names

datapath = "./data/cifar-100-python"
data_dict, tr_data100, tr_clabels100, tr_flabels100, te_data100, te_clabels100, te_flabels100, clabel_names100, flabel_names100 = get_cifar100(datapath, 4)

print(len(tr_data100))
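The code above calls `unpickle` without defining it. Here is a minimal sketch of a loader for the CIFAR-100 "python version" files. It assumes the files were pickled under Python 2, which is why it passes `encoding='latin1'`: this decodes Python 2 byte strings, including the dictionary keys, as `str`, so lookups like `data_dict['data']` work under Python 3. With `encoding='bytes'` the keys would instead be `b'data'`, `b'fine_labels'`, and so on.

```python
import pickle


def unpickle(path):
    """Load one CIFAR-100 python-version file ('train', 'test' or 'meta').

    Assumption: the file was pickled under Python 2, so encoding='latin1'
    decodes its byte strings (including dict keys) to str under Python 3.
    """
    with open(path, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')
```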

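I want to filter train_data by class_idx = 4 (using train_coarse_labels). The original array has 50000 rows, so the filtered result should have 5000. Instead, I get something far larger than the original size (7 million+). What is wrong with my function? Thanks.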
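A likely explanation: when `numpy.append` is called without an `axis` argument, it flattens both inputs. Each 3072-value image row (32 × 32 × 3) is therefore appended as 3072 separate scalars, and `len()` ends up counting pixel values rather than images. The CIFAR-100 training set has 500 images per fine class and 5 fine classes per coarse class, so one coarse class holds 2,500 images. 2,500 × 3,072 = 7,680,000, which matches the "7 million+" figure. The sketch below uses small hypothetical stand-in arrays (not real CIFAR data) to show the flattening, then two alternatives that keep the rows 2-D: a boolean mask, and the original loop fixed with `axis=0`.

```python
import numpy as np

# Hypothetical stand-ins: 10 "images" of 3072 uint8 values each,
# with made-up coarse labels (not real CIFAR-100 data).
train_data = (np.arange(10 * 3072) % 256).astype(np.uint8).reshape(10, 3072)
train_coarse_labels = [4, 1, 4, 0, 4, 7, 3, 4, 19, 2]
class_idx = 4

# Loop from the question: numpy.append without axis flattens to 1-D,
# and starting from numpy.empty((0)) also upcasts to float64.
filt_flat = np.empty((0))
for i, v in enumerate(train_coarse_labels):
    if v == class_idx:
        filt_flat = np.append(filt_flat, train_data[i])
print(filt_flat.shape)  # (12288,) -> 4 images * 3072 values
print(filt_flat.dtype)  # float64

# Vectorized alternative: a boolean mask selects whole rows in one step.
mask = np.asarray(train_coarse_labels) == class_idx
filt_rows = train_data[mask]
print(filt_rows.shape)  # (4, 3072)

# Loop kept but fixed: start 2-D and append (1, 3072) slices along axis 0.
filt_loop = np.empty((0, 3072), dtype=train_data.dtype)
for i, v in enumerate(train_coarse_labels):
    if v == class_idx:
        filt_loop = np.append(filt_loop, train_data[i:i + 1], axis=0)
print(filt_loop.shape)  # (4, 3072)
```

The mask approach is usually preferable. Every `numpy.append` call copies the whole array, so the loop is quadratic in the number of selected rows. The same `mask` can also filter the labels, for example `np.asarray(train_fine_labels)[mask]`.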

0 answers