Why is the data generator unreasonably slow in this simple code (Python yield)?

Asked: 2018-04-07 08:24:23

Tags: python python-3.x deep-learning keras yield

Suppose I have a simple data loader like the one below. It loads a pickle file of about 50 MB, and a single call takes about 39.5 ms.

import os
import pickle

def loader(image_id):
    # CACHE_DIR is the directory holding the pickled training data
    cache_path = os.path.join(CACHE_DIR, 'train', '{:05}.pickle')
    with open(cache_path.format(image_id), 'rb') as f:
        item = pickle.load(f)
        mask = item['contour']
    return mask

%timeit loader(0)
%timeit loader(200)

### --- output ----
### 39.4 ms ± 442 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
### 40.3 ms ± 733 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

So I created a generator that runs the function above. The batch size (number of loader calls per batch) here is only 3.

def data_generator(dataset, config, batch_size=1):
    b = 0  # batch item index
    image_index = -1

    while True:
        image_index = (image_index + 1) % 2000
        gt_mask = loader(image_index)

        b += 1
        if b >= batch_size:         # Batch full?
            inputs = []
            outputs = []
            yield inputs, outputs   # dummy return
            b = 0                   # start a new batch
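
By construction, each `next()` on this generator should trigger exactly `batch_size` calls to `loader`, i.e. roughly 3 × 40 ms ≈ 120 ms of I/O per batch. A minimal sketch with a hypothetical stub loader (standing in for the real pickle I/O, which is not reproduced here) confirms the per-batch call count:

```python
calls = []  # records every image_index the stub loader receives

def loader(image_id):
    # hypothetical stub standing in for the real ~40 ms pickle load
    calls.append(image_id)
    return None

def data_generator(batch_size=1):
    b = 0               # batch item index
    image_index = -1
    while True:
        image_index = (image_index + 1) % 2000
        loader(image_index)
        b += 1
        if b >= batch_size:   # batch full?
            yield [], []      # dummy return, as in the question
            b = 0

gen = data_generator(batch_size=3)
next(gen)
print(len(calls))  # → 3: one loader call per image in the batch
```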

However, the measured time is wildly different, reaching about 4.6 seconds per call. How can that happen?

train_generator = data_generator(dataset_train, config, batch_size=3)
%timeit inputs, outputs = next(train_generator)
%timeit inputs, outputs = next(train_generator)
%timeit inputs, outputs = next(train_generator)
%timeit inputs, outputs = next(train_generator)

### --- output ----
### 4.63 s ± 137 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
### 4.44 s ± 100 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
### 4.32 s ± 73.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
### 4.25 s ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
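
One way to cross-check these figures outside of `%timeit` (a sketch, using a hypothetical stub loader that sleeps 40 ms to mimic the measured pickle load) is to time individual `next()` calls with `time.perf_counter`. With a ~40 ms load and `batch_size=3`, each batch should cost on the order of 0.12 s, not several seconds:

```python
import time

def loader(image_id):
    # hypothetical stub: sleep 40 ms to mimic the measured pickle load
    time.sleep(0.04)
    return None

def data_generator(batch_size=1):
    b = 0
    image_index = -1
    while True:
        image_index = (image_index + 1) % 2000
        loader(image_index)
        b += 1
        if b >= batch_size:
            yield [], []
            b = 0

gen = data_generator(batch_size=3)
for _ in range(4):
    t0 = time.perf_counter()
    next(gen)
    print(f'{time.perf_counter() - t0:.3f} s')  # ~0.12 s per batch with the stub
```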

If this is caused by the I/O calls, how can it be fixed?

0 Answers:

There are no answers yet.