如何将keras ImageDataGenerator类应用于TFRecordsDataset以进行扩展?

时间:2019-06-09 20:34:56

标签: python tensorflow keras

我想有效地从TFRecords文件中检索示例,并使用keras ImageDataGenerator类对其进行扩充,但是据我了解,ImageDataGenerator只能采样numpy数组,pandas DataFrames或目录(由通过PIL可读的图像)。我还知道,可以使用tf.data.Dataset将其转换为tf.data.Datset.from_generator()对象,但不能将其转换为相反的方法。

也许我应该只使用ImageDataGenartor.flow_from_directory(),但我认为它的速度较慢,尽管我并未使用大型数据集对其进行精确测量。如果大致相等,我将感谢一位消息人士指出。下一个代码显示了一个我希望自己的样子的示例:

from tf.keras.preprocessing.image import ImageDataGenerator
from tf.data import TFRecordDataset;
import tensorflow as tf;
from models import DenseNetBN100;

eps=150; batch_size=32; tot_examples=50000;
imre_shape=(32,32,3); lare_shape=(10,);
train_fname='cifar10_trainRecords';
model = DenseNetBN100(imre_shpae, lare_shape);

def _parse_record(proto, clip=False):
    features = {
      'image':tf.FixedLenFeature([],tf.string),
      'label':tf.FixedLenFeature([],tf.string),
    }    
    example = tf.parse_single_example(proto, features)
    im = tf.decode_raw(example['image'], tf.float32)
    im = tf.reshape(im, imre_shape)
    la = tf.decode_raw(example['label'], tf.int8)
    la = tf.reshape(la, lare_shape);
    la = tf.cast(la, tf.float32);
    return im, la;
dtst = TFRecordDataset(train_fname).map(_parse_record)
dtst=dtst.repeat(eps).shuffle(10000)
dtst=dtst.batch(batch_size)

train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest');

train_generator = train_datagen.flow(dtst)
model.fit(train_generator, epochs=eps
    steps_per_epoch=tot_examples//batch_size)

由于ImageDataGenerator的管理,仍然有可能通过将tf.data.Dataset转换为tf.data.Dataset来提高效率。但这只是另一种理论,我认为拥有可信赖的资源比仅使用稀疏的个人测量方法会更好。

0 个答案:

没有答案