我想有效地从TFRecords
文件中检索示例,并使用keras ImageDataGenerator
类对其进行扩充,但是据我了解,ImageDataGenerator
只能采样numpy数组,pandas DataFrames或目录(由通过PIL可读的图像)。我还知道,可以使用tf.data.Dataset
将其转换为tf.data.Datset.from_generator()
对象,但不能将其转换为相反的方法。
也许我应该只使用ImageDataGenartor.flow_from_directory()
,但我认为它的速度较慢,尽管我并未使用大型数据集对其进行精确测量。如果大致相等,我将感谢一位消息人士指出。下一个代码显示了一个我希望自己的样子的示例:
from tf.keras.preprocessing.image import ImageDataGenerator
from tf.data import TFRecordDataset;
import tensorflow as tf;
from models import DenseNetBN100;
eps=150; batch_size=32; tot_examples=50000;
imre_shape=(32,32,3); lare_shape=(10,);
train_fname='cifar10_trainRecords';
model = DenseNetBN100(imre_shpae, lare_shape);
def _parse_record(proto, clip=False):
features = {
'image':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.string),
}
example = tf.parse_single_example(proto, features)
im = tf.decode_raw(example['image'], tf.float32)
im = tf.reshape(im, imre_shape)
la = tf.decode_raw(example['label'], tf.int8)
la = tf.reshape(la, lare_shape);
la = tf.cast(la, tf.float32);
return im, la;
dtst = TFRecordDataset(train_fname).map(_parse_record)
dtst=dtst.repeat(eps).shuffle(10000)
dtst=dtst.batch(batch_size)
train_datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest');
train_generator = train_datagen.flow(dtst)
model.fit(train_generator, epochs=eps
steps_per_epoch=tot_examples//batch_size)
由于ImageDataGenerator
的管理,仍然有可能通过将tf.data.Dataset
转换为tf.data.Dataset
来提高效率。但这只是另一种理论,我认为拥有可信赖的资源比仅使用稀疏的个人测量方法会更好。