Loading a custom dataset for Keras

Time: 2016-11-01 17:21:29

Tags: python image-processing machine-learning keras

MNIST example:

From the [neon example](http://neon.nervanasys.com/index.html/mnist.html) (similar to Keras):

```python
from neon.data import MNIST

mnist = MNIST()

(X_train, y_train), (X_test, y_test), nclass = mnist.load_data()
```

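I want to get the same set of tuples for the UCF_CC_50 dataset.

It is a dataset of 50 different images, each a bird's-eye view of a crowded area. I am modifying the code behind the MNIST example above.

All the images have been downloaded and are in an "Images" folder.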
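The MNIST loader can call `reshape(-1, 784)` only because every MNIST image is 28 x 28. The UCF_CC_50 images are not all the same size, so they need to be brought to a common shape before they can be stacked into one `X` array.

Below is a minimal numpy-only sketch. It assumes each file has already been read into a 2-D array, for example with Pillow's `np.asarray(Image.open(path).convert('L'))`. The helper names `resize_nearest` and `images_to_array` are hypothetical, and the 256 x 256 default size is an arbitrary choice.

```python
import numpy as np


def resize_nearest(img, height, width):
    """Nearest-neighbour resize of a 2-D (grayscale) array to (height, width)."""
    # For each output row/column, pick the closest source row/column index.
    rows = (np.arange(height) * img.shape[0] / height).astype(int)
    cols = (np.arange(width) * img.shape[1] / width).astype(int)
    return img[rows[:, None], cols[None, :]]


def images_to_array(images, height=256, width=256, normalize=True):
    """Stack variable-size images into an (N, height * width) float32 array.

    `images` is any sequence of 2-D arrays; (H, W, 3) colour arrays are
    averaged down to a single gray channel first.
    """
    out = np.empty((len(images), height * width), dtype=np.float32)
    for i, img in enumerate(images):
        img = np.asarray(img, dtype=np.float32)
        if img.ndim == 3:  # collapse colour channels to one gray channel
            img = img.mean(axis=2)
        out[i] = resize_nearest(img, height, width).ravel()
    if normalize:
        out /= 255.0  # same scaling as the MNIST loader's `X / 255.`
    return out
```

Nearest-neighbour resizing keeps the example dependency-free. If Pillow is available, `Image.resize` with a smoother filter is usually the better choice for real training data.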

Here is the `__init__` of the `Dataset` base class:

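```python
def __init__(self, filename, url, size, path='.', subset_pct=100):
    # parameters to use in dataset config serialization
    super(Dataset, self).__init__(name=None)
    self.filename = filename      # name of the data file inside `path`
    self.url = url                # where the file is downloaded from
    self.size = size              # expected file size in bytes
    self.path = path              # local directory holding the data
    self.subset_pct = subset_pct
    self._data_dict = None
    if subset_pct != 100:
        # placeholder to use partial data set
        # NotImplemented is a non-callable constant; the exception class is NotImplementedError
        raise NotImplementedError('subset percentage feature is not yet implemented')
```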

This is what I have so far. I don't understand how to modify `__init__`.

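```python
from neon.data.datasets import Dataset
from neon.data.dataiterator import ArrayIterator
from neon.util.persist import pickle_load


class UCF(Dataset):
    def __init__(self, path='.', subset_pct=100, normalize=True):
        super(UCF, self).__init__('Images',    # filename
                                  '//url',     # download URL (placeholder)
                                  15296311,    # file size in bytes (same value as the MNIST loader)
                                  path=path,
                                  subset_pct=subset_pct)
        self.normalize = normalize

    def load_data(self):
        filepath = self._valid_path_append(self.path, self.filename)

        # Still MNIST logic: expects a single pickle file, not a folder of images
        with open(filepath, 'rb') as ucf:
            (X_train, y_train), (X_test, y_test) = pickle_load(ucf)
            X_train = X_train.reshape(-1, 784)  # 784 = 28 * 28, the MNIST image size
            X_test = X_test.reshape(-1, 784)

            if self.normalize:
                X_train = X_train / 255.
                X_test = X_test / 255.

        return (X_train, y_train), (X_test, y_test), 10  # 10 = number of MNIST classes

    def gen_iterators(self):
        (X_train, y_train), (X_test, y_test), nclass = self.load_data()
        # lshape=(1, 28, 28): one channel of 28 x 28 pixels, again MNIST-specific
        train = ArrayIterator(X_train,
                              y_train,
                              nclass=nclass,
                              lshape=(1, 28, 28),
                              name='train')
        val = ArrayIterator(X_test,
                            y_test,
                            nclass=nclass,
                            lshape=(1, 28, 28),
                            name='valid')
        self._data_dict = {'train': train,
                           'valid': val}
        return self._data_dict
```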
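The `load_data` above still unpickles an MNIST-style file. Its 784 / `(1, 28, 28)` / 10-class values describe MNIST, not UCF_CC_50.

The minimal sketch below shows only the last step: turning an image array `X` and one target per image into the `(X_train, y_train), (X_test, y_test)` nesting that `MNIST.load_data()` returns. Several parts are assumptions rather than part of neon or the dataset:

* `split_tuples` is a hypothetical helper.
* The test-set size and seed are arbitrary.
* The targets `y` might be, for example, per-image head counts derived from the dataset's annotations.

```python
import numpy as np


def split_tuples(X, y, n_test, seed=0):
    """Shuffle and split into ((X_train, y_train), (X_test, y_test)),
    mirroring the nesting of neon's MNIST.load_data() (without nclass)."""
    X = np.asarray(X)
    y = np.asarray(y).reshape(len(X), -1)  # one row of targets per image
    if not 0 < n_test < len(X):
        raise ValueError("n_test must be between 1 and len(X) - 1")
    # Fixed seed so the same images land in the test set on every run
    order = np.random.RandomState(seed).permutation(len(X))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return (X[train_idx], y[train_idx]), (X[test_idx], y[test_idx])
```

With 50 images, `split_tuples(X, counts, n_test=10)` gives a 40/10 split. The iterators would then need `lshape=(1, H, W)` for whatever `H`, `W` the images were resized to.

Because a count is a regression target rather than one of 10 classes, check whether your neon version's `ArrayIterator` supports `make_onehot=False`. In Keras, the same arrays can go straight to `model.fit`; a convolutional model using the default channels-last layout expects `X_train.reshape(-1, H, W, 1)`.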

Can anyone help me?

0 answers