将 tf.data.Dataset 的 interleave 与 model.fit 一起使用时,模型训练卡住不动

时间:2020-06-25 21:10:21

标签: python multithreading tensorflow multiprocessing conv-neural-network

目标是使用生成器从目录中读取数据,以线程安全的方式进行模型训练。

我用 tf.data.Dataset 包装了 ImageDataGenerator.flow_from_directory() 返回的生成器,然后用 interleave 交错读取它的多个实例。

#%% Create data generators

from random_eraser import get_random_eraser

def get_gens(size=299,
             bs_per_gpu=128,
             eraser=False,
             get_classes=False,
             cores=31):
    """Build the train / validation / test input pipelines.

    Three Keras ``ImageDataGenerator.flow_from_directory`` iterators are
    created (augmented training subset, plain validation subset, plain test
    set) and each one is wrapped in an interleaved ``tf.data.Dataset``.

    Args:
        size: Currently unused; the image size is taken from the
            module-level ``img_shape`` instead.
        bs_per_gpu: Per-device batch size; the global batch size is
            ``bs_per_gpu * num_gpus``.
        eraser: When true, the function returned by ``get_random_eraser`` is
            used as the training ``preprocessing_function``.
        get_classes: When true, return only the class-index arrays of the
            training and validation iterators.
        cores: Number of interleaved ``from_generator`` sources per dataset
            (also used as ``cycle_length`` and ``num_parallel_calls``).

    Returns:
        ``(train_classes, val_classes)`` if ``get_classes`` is true,
        otherwise ``(train_ds, val_ds, test_ds)`` as ``tf.data.Dataset``
        objects yielding ``(float32, float32)`` tuples.
    """
    # NOTE(review): `size` is never read; `target_size` below comes from the
    # global `img_shape`. Confirm whether `(size, size)` was intended.
    bs = bs_per_gpu * num_gpus

    func_eraser = None
    if eraser:
        func_eraser = get_random_eraser(p=0.5,
                                        s_l=0.01,
                                        s_h=0.05,
                                        r_1=0.3,
                                        r_2=1/0.3,
                                        pixel_level=True)

    # Only the training generator augments; all three rescale to [0, 1].
    train_datagen = ImageDataGenerator(preprocessing_function=func_eraser,
                                       rescale=1/255.,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       fill_mode='constant',
                                       cval=0.0,
                                       horizontal_flip=True,
                                       rotation_range=45,
                                       brightness_range=[0.5, 1.5],
                                       zoom_range=[0.8, 1.0],
                                       validation_split=0.3)

    val_datagen = ImageDataGenerator(rescale=1/255.,
                                     validation_split=0.3)

    test_datagen = ImageDataGenerator(rescale=1/255.)

    # Train and validation share the directory, split ratio and seed.
    train_generator = train_datagen.flow_from_directory(train_dir,
                                                        target_size=img_shape,
                                                        batch_size=bs,
                                                        seed=1337,
                                                        subset='training')

    val_generator = val_datagen.flow_from_directory(train_dir,
                                                    target_size=img_shape,
                                                    batch_size=bs,
                                                    shuffle=False,
                                                    seed=1337,
                                                    subset='validation')

    test_generator = test_datagen.flow_from_directory(test_dir,
                                                      target_size=img_shape,
                                                      batch_size=bs,
                                                      shuffle=False,
                                                      classes=['test'])

    if get_classes:
        return train_generator.classes, val_generator.classes

    def multithread_gen(gen, cores):
        """Interleave ``cores`` ``from_generator`` datasets built from ``gen``.

        ``gen`` must be a zero-argument callable returning the iterable to
        read from. Every interleaved source calls the same ``gen``.
        """
        Dataset = tf.data.Dataset
        # The string elements only fix how many sources are interleaved;
        # their values are ignored by the map function.
        ds = Dataset.from_tensor_slices([str(x) for x in range(cores)])
        ds = ds.interleave(lambda x: Dataset.from_generator(gen,
                                                            output_types=(tf.float32, tf.float32)),
                           cycle_length=cores,
                           block_length=1,
                           num_parallel_calls=cores)
        # NOTE: Dataset transformations return a new dataset; to enable
        # prefetching the result must be assigned, e.g. `ds = ds.prefetch(n)`.
        return ds

    # The datasets get their own names on purpose. `from_generator` invokes
    # the callable lazily, at iteration time; if the result were assigned
    # back to `train_generator`, the closure would then return the Dataset
    # itself rather than the Keras iterator (late-binding closure) and the
    # pipeline would end up reading from itself.
    train_ds = multithread_gen(lambda: train_generator, cores=cores)
    val_ds = multithread_gen(lambda: val_generator, cores=cores)
    test_ds = multithread_gen(lambda: test_generator, cores=cores)

    return train_ds, val_ds, test_ds

#train_generator, val_generator, test_generator = get_gens(size, bs_per_gpu)

# Set up synchronous multi-device training and report the replica count.
strategy = tf.distribute.MirroredStrategy()
print(
    'Number of devices: {}'.format(strategy.num_replicas_in_sync)
)

然后我用下面的函数训练模型:

def train_model(epochs):
    """Fit the module-level ``model`` on ``train_generator``.

    Args:
        epochs: Number of epochs passed to ``model.fit``.

    Returns:
        The object returned by ``model.fit``.
    """
    # `workers` / `use_multiprocessing` are intentionally not passed: per the
    # Keras documentation they apply to generator or `keras.utils.Sequence`
    # input only. With a `tf.data.Dataset`, input parallelism belongs in the
    # dataset pipeline itself (interleave / prefetch).
    # NOTE(review): assumes `train_generator` is the tf.data.Dataset returned
    # by `get_gens` - confirm against the caller.
    model_history = model.fit(train_generator,  # TODO: add val_generator?
                              epochs=epochs,
                              callbacks=callbacks,
                              steps_per_epoch=int(num_train/bs),
                              # class_weight=weights_dict,
                              )

    return model_history

但是在打印出以下内容之后,训练就卡住了:

Epoch 1/20

INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1

在整个过程中,CPU 核心和 GPU 几乎都没有被使用,利用率一直是 0%。

0 个答案:

没有答案