使用tf.GradientTape和MirroredStrategy和Dataset.from_generator进行训练

时间:2020-10-06 07:11:59

标签: tensorflow keras

我正在尝试结合tf.data.Dataset.from_generatortf.GradientTape来使用两个GPU训练网络。

我从官方的tensorflow仓库here找到了这个例子。在此示例中,它使用tf.data.Dataset类读取训练数据集。

另一方面,我的网络使用tf.keras.utils.Sequence类从磁盘读取训练图像,以同时扩充训练数据集。

问题是在给定相同的训练数据集的情况下,使用tf.GradientTape的MSE损失与使用model.fit的情况不同。

我尝试的代码如下。 我正在使用Tensorflow 2.3和Ubuntu 20.04

火车循环的类,

class Train(object):
def __init__(self, epochs, lr, steps_per_epoch, model, batch_size, patch_size, strategy, enable_function, optimizer, last_epoch):
    self.epochs = epochs
    self.batch_size = batch_size
    self.strategy = strategy
    self.optimizer = optimizer
    self.model = model
    self.enable_function = enable_function
    self.steps_per_epoch = steps_per_epoch
    self.patch_size = patch_size
    self.lr = lr
    self.last_epoch = last_epoch

def compute_loss(self, label, predictions):
    loss = tf.reduce_sum(self.loss_function(label, predictions)) * (1. / (self.batch_size))
    loss += (sum(self.model.losses) * 1. / self.strategy.num_replicas_in_sync)
    return loss
     
def loss_function(self, target, prediction):
    l2loss = tf.keras.losses.MSE(target, prediction)
     
    return l2loss * 1000

def train_step(self, inputs):
    """One train step.
    Args:
      inputs: one batch input.
    Returns:
      loss: Scaled loss.
    """

    image, target = inputs
    with tf.GradientTape() as tape:
      predictions = self.model(image)
      loss = self.compute_loss(target, predictions)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

    # self.train_acc_metric(target, predictions)

    return loss

def custom_loop(self, train_dist_dataset, strategy):  #test_dist_dataset,
    """Custom training and testing loop.
    Args:
      train_dist_dataset: Training dataset created using strategy.
      test_dist_dataset: Testing dataset created using strategy.
      strategy: Distribution strategy.
    Returns:
      train_loss, train_accuracy, test_loss, test_accuracy
    """

    def distributed_train_epoch(ds):
        total_loss = 0.0
        num_train_batches = 0.0
        for one_batch in ds:
            per_replica_loss = strategy.run(self.train_step, args=(one_batch,))
            total_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            num_train_batches += 1

        return total_loss, num_train_batches

    if self.enable_function:
        distributed_train_epoch = tf.function(distributed_train_epoch) 
     
    for epoch in range(self.epochs):
        
        print('start training1...')
        train_loss, num_train_batches_ = distributed_train_epoch(train_dist_dataset)

        template = ('Epoch: {}, Train Loss: {}') 
        print(template.format(epoch, train_loss))

        if np.mod(epoch+1, 2) == 0:
            checkpoint_dir = os.path.join(opt.ckpt_path, 'epoch_{}'.format(epoch + self.last_epoch + 1))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            self.model.save(checkpoint_dir, save_format='tf')

    return train_loss / num_train_batches_

这是使用tf.GradientTape的主要功能,

def train_main():
    num_gpu = 2
    devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
    mirrored_strategy = tf.distribute.MirroredStrategy(devices)

  
    patch_size = opt.PatchSize
    batch_size = opt.BatchSize
    traindata_dir = opt.traindata_dir        
    data_augmentation = opt.data_augmentation
    traindata_loader = DatasetFromFolder(traindata_dir, batch_size, patch_size, data_augmentation=data_augmentation)        
    steps_per_epoch = 13817 // batch_size  # num data set

    def generator():
        multi_enqueuer = tf.keras.utils.OrderedEnqueuer(traindata_loader, use_multiprocessing=True)
        multi_enqueuer.start(workers=12, max_queue_size=1000)

        ''' generate the mini-batch training data for iteration '''
        for tt in range(int(steps_per_epoch)):
            input, target = next(multi_enqueuer.get())
            yield input, target

        # while True:
        #     input, target = next(multi_enqueuer.get())
        #     yield input, target

    dataset_train = tf.data.Dataset.from_generator(generator, output_types=(tf.float32, tf.float32),
                                                   output_shapes=(tf.TensorShape([None, None, None, 1]),
                                                                  tf.TensorShape([None, None, None, 1])))
   
    with mirrored_strategy.scope():

        checkpoint_dir = opt.ckpt_path
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        latest = os.listdir(checkpoint_dir)
        if len(latest) >= 1:
            model, last_epoch = restore_checkpoints(opt.ckpt_path, training=True)
            model.summary()
        else:
            print(' [*] No checkpoints for a pre-trained model. ')
            print(' [*] Build model. ')
            input = tf.keras.Input(shape=(None, None, 1))
            output = DenoisingNet(out_channel=1, base_filter=32, feat=128, activation='relu', bias=True, norm=None)(
                input)
            model = tf.keras.Model(inputs=input, outputs=output)
            last_epoch = 0
            model.summary()

        optimizer = tf.keras.optimizers.Adam(lr=opt.lr)
       
        train_dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset_train)

        trainer = Train(epochs=20, lr=opt.lr, steps_per_epoch=steps_per_epoch, model=model,
                        batch_size=batch_size,
                        patch_size=patch_size, strategy=mirrored_strategy,
                        enable_function=True,
                        optimizer=optimizer,
                        last_epoch=last_epoch)

        print('start training...')
        trainer.custom_loop(train_dist_dataset, mirrored_strategy)

和使用model.fit

的主要功能
    def generator():
        multi_enqueuer = tf.keras.utils.OrderedEnqueuer(traindata_loader, use_multiprocessing=True)
        multi_enqueuer.start(workers=12, max_queue_size=1000)
        while True:
            input, target = next(multi_enqueuer.get())
            yield input, target

    model.fit(generator(), epochs=opt.nEpochs, steps_per_epoch=steps_per_epoch, use_multiprocessing=False,
              workers=1, verbose=1,
              max_queue_size=1000,
              validation_data=generator_valid(),
              validation_steps=5)

tf.GradientTape的培训损失

Epoch: 0, Train Loss: 602256.9375
Epoch: 1, Train Loss: 247519.140625
Epoch: 2, Train Loss: 213128.125
Epoch: 3, Train Loss: 256719.359375

model.fit的培训损失

432/431 [==============================] - 24s 56ms/step - loss: 3.2862 - PSNR: 31.3674 - val_loss: 0.8055 - val_PSNR: 31.3003
432/431 [==============================] - 19s 44ms/step - loss: 0.2481 - PSNR: 36.7073 - val_loss: 0.4737 - val_PSNR: 33.4911
432/431 [==============================] - 18s 41ms/step - loss: 0.1609 - PSNR: 38.6651 - val_loss: 0.3542 - val_PSNR: 35.4089

tf.GradientTapemodel.fit相同的情况下使用while循环时,无限的while循环会冻结训练过程。 我猜想在使用generator()的主要函数中tf.GradientTape是正确的。

请给我一个解决这个问题的建议

非常感谢您的帮助。

0 个答案:

没有答案
相关问题