Keras' flow_from_directory'非常慢

时间:2018-01-18 21:21:30

标签: python-3.x machine-learning keras

我有一个图像分类项目,我正在研究它有17,255个图像和49个类别。这是一个概念验证练习。实际的最终产品将涉及100,000到500,000张图像。考虑到大量图像及其大小,我决定调查Keras' flow_from_directory'功能。

当我最初针对整个图像集运行下面的代码时,它运行了一个多小时而没有完成。为了解决这个问题,我创建了一个图像和目录类别的子集。对于一百个左右的图像,脚本在大约30秒内完成。

当我将内容增加到大约1,400张图像时,脚本需要30多分钟才能完成。对于我的数据集,这将是每小时2,800张图片或超过6小时(随意检查我的数学)。这只是数据生成部分,不包括任何实际培训

我在拥有8个CPU和50 Gig RAM的Google实例上运行。脚本运行时的CPU和内存使用量很小,因此硬件不是问题。

机器规格:

instance-4 > uname -a
Linux instance-4 4.4.0-109-generic #132~14.04.1-Ubuntu SMP Tue Jan 9 21:46:42 UTC 2018 x86_64 x86_64 x86_64 GNU/Linux

Python规范:

>>> print(keras.__version__)
2.1.2
>>> import tensorflow as tf
>>> print(tf.__version__)
1.4.1
>>> 
instance-4 > python -V
Python 3.6.3 :: Anaconda, Inc.

文件存储是Google Cloud。示例目录是:

instance-4 > ls -1 ./data/val
cat1
cat2
cat3
cat4
cat5
cat6

在每个目录/类别中都是指向实际图像文件的符号链接(也在Google Cloud上)。

我想到链接可能是问题所在,但是当我运行一百个左右的图像文件时,其性能与符号链接(约30秒)大致相同

所以我的问题是:我做错了什么,或者是Keras' flow_from_directory'只是无法处理大量图像(尽管有广告/文档)?

示例代码:

#!/usr/bin/env python

import warnings

#... Supress TensorFlow warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import keras

from keras.preprocessing.image import ImageDataGenerator

from datetime import datetime
import time

test_datagen = ImageDataGenerator()
validation_dir = './data/val'

start_time = time.time()
print( str(datetime.now()) )

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(100,100),
        batch_size=32,
        class_mode='categorical',
        follow_links=True
)

print(validation_generator)

print("--- %s seconds ---" % (time.time() - start_time))

1 个答案:

答案 0 :(得分:1)

问题解决了。我将所有内容从Google Cloud复制到我的实例上的磁盘上,整个图像集的生成器运行时间不到2秒。