我想创建一个与cifar10格式相同的数据集(data.npy)。示例omniglot数据集包含
data = [0001_01.png,0001_02.png,0001_03.png,0001_04.png,0002_01.png,0002_02.png,0002_03.png]
class_name,filename = data[0].split('_')
每个文件都附加了class.There有1600个类,每个类有20个样本。预期的数据集(data.npy)形状为(1600,20,784)
但我得到的形状是(1,20,784)。下面给出的是片段
classes = []
examples = []
prev = files[0].split('_')[0]
for f in files:
cur_id = f.split('_')[0]
cur_pic = misc.imresize(misc.imread('new_data/' + f),[28,28])
cur_pic = (np.float32(cur_pic)/255).flatten()
if prev == cur_id:
examples.append(cur_pic)
else:
classes.append(examples)
examples = [cur_pic]
prev = cur_id
np.save('data',np.asarray(classes))
任何建议都会有很大帮助。以上代码取自https://github.com/zergylord/oneshot