Hello, I downloaded the CIFAR-10 dataset.
In my code, it is loaded as follows:
import cv2
import numpy as np
from keras.datasets import cifar10
from keras import backend as K
from keras.utils import np_utils
nb_train_samples = 3000  # 3000 training samples
nb_valid_samples = 100  # 100 validation samples
num_classes = 10
def load_cifar10_data(img_rows, img_cols):
    # Load cifar10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()
    # Resize training and validation images.
    # image_dim_ordering() is the older Keras API; newer Keras uses
    # K.image_data_format() == 'channels_first' for the same check.
    # cv2.resize takes the target size as (width, height).
    if K.image_dim_ordering() == 'th':
        X_train = np.array([cv2.resize(img.transpose(1,2,0), (img_cols,img_rows)).transpose(2,0,1) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img.transpose(1,2,0), (img_cols,img_rows)).transpose(2,0,1) for img in X_valid[:nb_valid_samples,:,:,:]])
    else:
        X_train = np.array([cv2.resize(img, (img_cols,img_rows)) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img, (img_cols,img_rows)) for img in X_valid[:nb_valid_samples,:,:,:]])
    # Transform targets to keras compatible format
    Y_train = np_utils.to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = np_utils.to_categorical(Y_valid[:nb_valid_samples], num_classes)
    return X_train, Y_train, X_valid, Y_valid
But downloading the dataset takes a long time, so I downloaded 'cifar-10-python.tar.gz' manually instead. How can I load it into the variables (X_train, Y_train), (X_valid, Y_valid) without using cifar10.load_data()?
Answer 0 (score: 0)
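Please excuse my English. I was also trying to load the CIFAR-10 dataset manually. I extracted cifar-10-python.tar.gz into a folder and loaded the file data_batch_1 from that folder into four arrays: x_train, y_train, x_test and y_test. 20% of data_batch_1 goes into x_test and y_test for validation, and the rest goes into x_train and y_train for training.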
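The approach described in this answer can be sketched as follows. This is a minimal illustration, not the answerer's own code: it assumes the archive extracts to a folder named `cifar-10-batches-py`, the helper names (`extract_cifar10`, `load_batch`, `batch_to_arrays`, `split_train_val`) are hypothetical, and the split simply holds out the last 20% of the batch without shuffling.

```python
import os
import pickle
import tarfile

import numpy as np


def extract_cifar10(tar_path, dest_dir):
    """Extract cifar-10-python.tar.gz and return the path of the extracted folder."""
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    # Assumption: the archive's top-level folder is 'cifar-10-batches-py'
    return os.path.join(dest_dir, "cifar-10-batches-py")


def load_batch(path):
    """Unpickle one batch file such as data_batch_1."""
    with open(path, "rb") as f:
        # The batches were pickled under Python 2; 'latin1' lets Python 3 read them
        return pickle.load(f, encoding="latin1")


def batch_to_arrays(batch):
    """Convert one unpickled batch dict into (images, labels) arrays."""
    # Each row holds 3072 bytes: 1024 red, then 1024 green, then 1024 blue values
    images = np.asarray(batch["data"], dtype=np.uint8).reshape(-1, 3, 32, 32)
    images = images.transpose(0, 2, 3, 1)  # (N, 3, 32, 32) -> (N, 32, 32, 3)
    labels = np.asarray(batch["labels"], dtype=np.int64)
    return images, labels


def split_train_val(x, y, val_fraction=0.2):
    """Hold out the last val_fraction of the samples (no shuffling)."""
    n_val = int(len(x) * val_fraction)
    n_train = len(x) - n_val
    return (x[:n_train], y[:n_train]), (x[n_train:], y[n_train:])


# Usage (paths are assumptions; adjust them to where the archive was saved):
# folder = extract_cifar10("cifar-10-python.tar.gz", ".")
# x, y = batch_to_arrays(load_batch(os.path.join(folder, "data_batch_1")))
# (x_train, y_train), (x_test, y_test) = split_train_val(x, y, 0.2)
```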
Answer 1 (score: 0)
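The following code reads the training and test images from the separate batch files described on the dataset website. It is adapted from this post, which also gives a good explanation.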
import pickle
import numpy as np
for i in range(1,6):
    path = 'data_batch_' + str(i)
    with open(path, mode='rb') as file:
        # note the encoding type is 'latin1'
        batch = pickle.load(file, encoding='latin1')
    if i == 1:
        x_train = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
        y_train = batch['labels']
    else:
        x_train_temp = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
        y_train_temp = batch['labels']
        x_train = np.concatenate((x_train, x_train_temp), axis=0)
        y_train = np.concatenate((y_train, y_train_temp), axis=0)
path = 'test_batch'
with open(path, 'rb') as file:
    # note the encoding type is 'latin1'
    batch = pickle.load(file, encoding='latin1')
x_test = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
y_test = batch['labels']
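To plug the manually downloaded files into the question's load_cifar10_data, the reading logic above can be wrapped in a function that returns the same ((x_train, y_train), (x_test, y_test)) structure as cifar10.load_data(). This is a minimal sketch, assuming the archive has already been extracted to a folder named cifar-10-batches-py; the names `unpack_batch`, `read_batch` and `load_cifar10_local` are hypothetical. Unlike the code above, it keeps the images as uint8 and returns the labels as (N, 1) column vectors, the layout load_data() uses with the default channels-last setting.

```python
import os
import pickle

import numpy as np


def unpack_batch(batch):
    """Turn one unpickled batch dict into (images, labels) arrays."""
    # Each row holds 3072 bytes: 1024 red, then 1024 green, then 1024 blue values
    x = np.asarray(batch["data"], dtype=np.uint8).reshape(-1, 3, 32, 32)
    x = x.transpose(0, 2, 3, 1)  # (N, 3, 32, 32) -> (N, 32, 32, 3), channels-last
    y = np.asarray(batch["labels"], dtype=np.uint8).reshape(-1, 1)  # column vector
    return x, y


def read_batch(path):
    """Unpickle one batch file and unpack it."""
    with open(path, "rb") as f:
        # The batches were pickled under Python 2; 'latin1' lets Python 3 read them
        return unpack_batch(pickle.load(f, encoding="latin1"))


def load_cifar10_local(folder="cifar-10-batches-py"):
    """Return ((x_train, y_train), (x_test, y_test)) read from an extracted folder."""
    parts = [read_batch(os.path.join(folder, "data_batch_%d" % i)) for i in range(1, 6)]
    x_train = np.concatenate([x for x, _ in parts], axis=0)
    y_train = np.concatenate([y for _, y in parts], axis=0)
    x_test, y_test = read_batch(os.path.join(folder, "test_batch"))
    return (x_train, y_train), (x_test, y_test)
```

In the question's load_cifar10_data, the first line would then become `(X_train, Y_train), (X_valid, Y_valid) = load_cifar10_local('path/to/cifar-10-batches-py')`. Because this sketch always returns channels-last images, the 'th' branch would first need `X_train = X_train.transpose(0, 3, 1, 2)` (and the same for X_valid) before its per-image `transpose(1,2,0)`.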
We can visualize the loaded data as follows:
import matplotlib.pyplot as plt
x_train = x_train.astype(np.uint8)
y_train = np.expand_dims(y_train, axis=1)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(np.squeeze(x_train[i]), cmap=plt.cm.binary)
    # y_train was expanded to shape (N, 1) above,
    # which is why you need the extra index
    plt.xlabel(class_names[y_train[i][0]])
plt.show()
Also, see here if the download time is your only problem and you would still like to use load_data().