tf.function uses up all of my CPU memory and crashes

Date: 2019-11-21 04:50:48

Tags: memory-leaks tensorflow2.0

In my program I define a train_step decorated with @tf.function:

@tf.function(input_signature=train_step_signature)
def train_step(group, inp, tar, label):
    # Shift the target sequence by one token for teacher forcing.
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    # persistent=True keeps the tape alive for the second gradient call
    # that is currently commented out below.
    with tf.GradientTape(persistent=True) as tape:
        classfication, predictions, _ = transformer(inp, tar_inp,
                                                    True,
                                                    enc_padding_mask,
                                                    combined_mask,
                                                    dec_padding_mask)
        loss = loss_function(tar_real, predictions)
        loss2 = tf.nn.softmax_cross_entropy_with_logits(label, classfication)
        loss = loss + loss2

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    class_loss(loss2)
    train_loss(loss)
    train_accuracy(tar_real, predictions)

    # gra = tape.gradient(loss2, transformer.trainable_variables)
    # optimizer.apply_gradients(zip(gra, transformer.trainable_variables))
    class_accuracy(tf.argmax(label, 1), classfication)
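
Because train_step is built with a fixed input_signature, it should be traced once and then reused; if a tf.function is instead retraced on every call (for example, when it is fed plain Python values), each trace adds a new graph to host memory. A quick toy sketch (not the model above) for spotting retracing:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def toy_step(x):
    # A Python-level print executes only while the function is being traced,
    # so seeing this message repeatedly inside a loop signals retracing.
    print('tracing toy_step')
    return tf.reduce_sum(x)

for _ in range(3):
    toy_step(tf.ones([4]))  # prints 'tracing toy_step' only once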

But when I run a training loop like the following:

tf.compat.v2.summary.trace_on(graph=True, profiler=True)
for epoch in range(EPOCHS):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    class_loss.reset_states()
    class_accuracy.reset_states()
    # inp -> portuguese, tar -> english
    for (batch, (group, inp, tar, label)) in enumerate(train_dataset):

        train_step(group, inp, tar, label)
        if batch % 50 == 0:
            print('Epoch {} Batch {} correct_Loss {:.4f} correct_Accuracy {:.4f} '
                  'class_accuracy {:.4f} class_loss {:.4f}'.format(
                      epoch + 1, batch, train_loss.result(), train_accuracy.result(),
                      class_accuracy.result(), class_loss.result()))

    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    print('Epoch {} correct_Loss {:.4f} correct_Accuracy {:.4f} '
          'class_accuracy {:.4f} class_loss {:.4f}'.format(
              epoch + 1, train_loss.result(), train_accuracy.result(),
              class_accuracy.result(), class_loss.result()))

    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    with summary_writer.as_default():
        tf.compat.v2.summary.scalar('train_class_loss', class_loss.result(), step=epoch)
        tf.compat.v2.summary.scalar('train_translate_loss', train_loss.result(), step=epoch)
        tf.compat.v2.summary.scalar('train_class_accuracy', class_accuracy.result(), step=epoch)
        tf.compat.v2.summary.scalar('train_translate_accuracy', train_accuracy.result(), step=epoch)
    train_loss.reset_states()
    train_accuracy.reset_states()
    class_loss.reset_states()
    class_accuracy.reset_states()
    if epoch % 10 == 0:
        for (batch, (group, inp, tar, label)) in enumerate(val_dataset):
            tar_inp = tar[:, :-1]
            tar_real = tar[:, 1:]
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
            classfication, predictions, _ = transformer(inp, tar_inp,
                                                        False,  # training=False during evaluation
                                                        enc_padding_mask,
                                                        combined_mask,
                                                        dec_padding_mask)
            loss = loss_function(tar_real, predictions)
            loss2 = tf.nn.softmax_cross_entropy_with_logits(label, classfication)
            class_accuracy(tf.argmax(label, 1), classfication)
            class_loss(loss2)
            train_loss(loss)
            train_accuracy(tar_real, predictions)
        with summary_writer.as_default():
            tf.compat.v2.summary.scalar('test_class_loss', class_loss.result(), step=epoch)
            tf.compat.v2.summary.scalar('test_translate_loss', train_loss.result(), step=epoch)
            tf.compat.v2.summary.scalar('test_class_accuracy', class_accuracy.result(), step=epoch)
            tf.compat.v2.summary.scalar('test_translate_accuracy', train_accuracy.result(), step=epoch)
        print('Eval {} correct_Loss {:.4f} correct_Accuracy {:.4f} '
              'class_accuracy {:.4f} class_loss {:.4f}'.format(
                  epoch + 1, train_loss.result(), train_accuracy.result(),
                  class_accuracy.result(), class_loss.result()))
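
One detail in this loop stands out independent of the model: tf.compat.v2.summary.trace_on(graph=True, profiler=True) is called once before the epoch loop, but trace_export / trace_off are never called, so the profiler keeps accumulating trace data in host memory for the entire multi-epoch run. A minimal sketch of scoping the trace to a single step instead ('logs/func' is a placeholder output directory; the batch is taken from train_dataset):

tf.compat.v2.summary.trace_on(graph=True, profiler=True)
group, inp, tar, label = next(iter(train_dataset))
train_step(group, inp, tar, label)  # run one step so there is a trace to export
with summary_writer.as_default():
    tf.compat.v2.summary.trace_export(
        name='train_step_trace', step=0, profiler_outdir='logs/func')
tf.compat.v2.summary.trace_off()  # stop collecting trace/profiler data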

My GPU memory use is about 8000 MB, but CPU memory grows from roughly 1 GB to more than 30 GB, at which point my program crashes. This problem has bothered me for a long time and has cost me a lot of money renting GPUs to test. Can anyone help me solve it? I suspect the problem is caused by the reshape in my transformer, which looks like this:

re_encou = tf.reshape(enc_output, (-1, MAX_LENGTH * d_model))

Is that right? I don't know how to solve the problem.
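
For what it's worth, inside a tf.function tf.reshape is a single graph op and does not by itself allocate more host memory on each step, so the reshape seems an unlikely culprit. One way to narrow the leak down is to log the process's resident memory at fixed points in the loop (a sketch assuming the third-party psutil package; log_rss is a hypothetical helper):

import os
import psutil  # third-party: pip install psutil

_process = psutil.Process(os.getpid())

def log_rss(tag):
    # Print this process's resident set size in MiB.
    print('{}: RSS = {:.1f} MiB'.format(tag, _process.memory_info().rss / 2**20))

# Hypothetical usage around the epoch loop above:
# log_rss('epoch {} start'.format(epoch))
# ... run the epoch ...
# log_rss('epoch {} end'.format(epoch))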

0 Answers:

No answers yet.