如何通过目录添加数据来训练图像

时间:2018-01-03 18:26:08

标签: python git numpy keras

我一直在通过flyyufelix" https://github.com/flyyufelix/cnn_finetune"浏览git资源库。微调初始v3网络我想训练网络检测疾病所以我有2套图像一个疾病,没有疾病。 git说X_train,Y_train,X_valid,Y_valid = load_data()他加载了cifar数据集,git要求我们创建自己的load_data()函数。作者的代码如下

import cv2
import numpy as np

from keras.datasets import cifar10
from keras import backend as K
from keras.utils import np_utils

nb_train_samples = 3000 # 3000 training samples
nb_valid_samples = 100 # 100 validation samples
num_classes = 10

def load_cifar10_data(img_rows, img_cols):

    # Load cifar10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()

    # Resize trainging images
    if K.image_dim_ordering() == 'th':
        X_train = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_valid[:nb_valid_samples,:,:,:]])
    else:
        X_train = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_valid[:nb_valid_samples,:,:,:]])

    # Transform targets to keras compatible format
    Y_train = np_utils.to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = np_utils.to_categorical(Y_valid[:nb_valid_samples],num_classes)
return X_train, Y_train, X_valid, Y_valid

我可以知道如何生成一个加载的函数  数据X_train,Y_train,X_valid,Y_valid = load_data()当我在pc中有目标时

2 个答案:

答案 0 :(得分:0)

使用Keras' ImageDataGenerator()课程并在其上调用flow_from_directory()。标签将从目录名称自动推断出来。所以,如果你有一个名为"疾病的目录,"然后Keras会推断该目录中的所有图像都被标记为"疾病,"对于标题为“没有疾病”的另一个目录,同样的事情也是如此。"例如。

我演示如何在this video中准备用于在Keras中训练CNN的图像数据。视频的前半部分是关于磁盘上的图像组织,然后下半部分将完成上述过程。

答案 1 :(得分:0)