目标是在训练模型时,通过生成器以线程安全的方式从目录中读取数据。
我用 tf.data.Dataset 包装了 ImageDataGenerator.flow_from_directory(),然后对多个实例做交错(interleave)读取。
#%% Create data generators
from random_eraser import get_random_eraser
def get_gens(size=299,
             bs_per_gpu=128,
             eraser=False,
             get_classes=False,
             cores=31):
    """Build the train / validation / test input pipelines.

    Three ``ImageDataGenerator`` directory iterators are created and each one
    is wrapped in a ``tf.data.Dataset`` that interleaves ``cores`` streams
    pulling from that iterator.

    Relies on module-level names that are not defined here: ``num_gpus``,
    ``img_shape``, ``train_dir``, ``test_dir``, ``ImageDataGenerator`` and
    ``tf``.

    Args:
        size: Not used by this function; the target size passed to
            ``flow_from_directory`` is the module-level ``img_shape``.
            NOTE(review): presumably meant to drive ``target_size`` -- confirm.
        bs_per_gpu: Per-device batch size; multiplied by ``num_gpus`` to get
            the batch size handed to ``flow_from_directory``.
        eraser: When true, ``get_random_eraser(...)`` is used as the training
            ``preprocessing_function``.
        get_classes: When true, return only the ``.classes`` attributes of the
            training and validation iterators.
        cores: Number of interleaved streams per dataset (``cycle_length`` and
            ``num_parallel_calls`` of ``Dataset.interleave``).

    Returns:
        ``(train_classes, val_classes)`` if ``get_classes`` is true, otherwise
        ``(train_dataset, val_dataset, test_dataset)``.
    """
    bs = bs_per_gpu * num_gpus

    func_eraser = None
    if eraser:
        func_eraser = get_random_eraser(p=0.5,
                                        s_l=0.01,
                                        s_h=0.05,
                                        r_1=0.3,
                                        r_2=1/0.3,
                                        pixel_level=True)

    # Training data is augmented; validation shares the same 70/30 split of
    # train_dir but is only rescaled.
    train_datagen = ImageDataGenerator(preprocessing_function=func_eraser,
                                       rescale=1/255.,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       fill_mode='constant',
                                       cval=0.0,
                                       horizontal_flip=True,
                                       rotation_range=45,
                                       brightness_range=[0.5, 1.5],
                                       zoom_range=[0.8, 1.0],
                                       validation_split=0.3)
    val_datagen = ImageDataGenerator(rescale=1/255.,
                                     validation_split=0.3)
    test_datagen = ImageDataGenerator(rescale=1/255.)

    # The directory iterators keep their own names for the rest of the
    # function. The original code rebound `train_generator` (etc.) to the
    # Dataset built from `lambda: train_generator`; because closures look the
    # name up when called, the lambda then returned the Dataset itself
    # instead of the Keras iterator.
    train_iter = train_datagen.flow_from_directory(train_dir,
                                                   target_size=img_shape,
                                                   batch_size=bs,
                                                   seed=1337,
                                                   subset='training')
    val_iter = val_datagen.flow_from_directory(train_dir,
                                               target_size=img_shape,
                                               batch_size=bs,
                                               shuffle=False,
                                               seed=1337,
                                               subset='validation')
    test_iter = test_datagen.flow_from_directory(test_dir,
                                                 target_size=img_shape,
                                                 batch_size=bs,
                                                 shuffle=False,
                                                 classes=['test'])

    if get_classes:
        return train_iter.classes, val_iter.classes

    def multithread_gen(gen, cores):
        """Interleave `cores` from_generator datasets built from `gen`."""
        Dataset = tf.data.Dataset
        # One dummy element per stream; its value is ignored by the map fn.
        ds = Dataset.from_tensor_slices([str(x) for x in range(cores)])
        ds = ds.interleave(
            lambda x: Dataset.from_generator(
                gen,
                output_types=(tf.float32, tf.float32)),
            cycle_length=cores,
            block_length=1,
            num_parallel_calls=cores)
        # NOTE(review): prefetching was left disabled in the original
        # (`ds.prefetch(...)` was commented out and its result unused).
        return ds

    train_dataset = multithread_gen(lambda: train_iter, cores=cores)
    val_dataset = multithread_gen(lambda: val_iter, cores=cores)
    test_dataset = multithread_gen(lambda: test_iter, cores=cores)
    return train_dataset, val_dataset, test_dataset
# Example call of get_gens (left disabled here):
#train_generator, val_generator, test_generator = get_gens(size, bs_per_gpu)
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
# Report the value of strategy.num_replicas_in_sync.
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
然后我用以下方法训练模型:
def train_model(epochs):
    """Fit the module-level ``model`` on ``train_generator``.

    Reads ``model``, ``train_generator``, ``callbacks``, ``num_train`` and
    ``bs`` from module scope.

    Args:
        epochs: Forwarded to ``model.fit`` as ``epochs``.

    Returns:
        Whatever ``model.fit`` returns.
    """
    # TODO: consider also passing val_generator and
    # class_weight=weights_dict (both were left disabled).
    fit_options = dict(
        epochs=epochs,
        callbacks=callbacks,
        # Batches drawn per epoch: sample count over batch size, truncated.
        steps_per_epoch=int(num_train/bs),
        workers=31,
        use_multiprocessing=True,
    )
    return model.fit(train_generator, **fit_options)
但是在打印以下内容后模型被卡住了:
Epoch 1/20
INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1
在此期间,CPU 核心和 GPU 的利用率一直接近 0%,几乎没有被使用。