TensorFlow 2.0 runs out of GPU memory when using tf.function

Asked: 2019-11-20 14:08:05

Tags: out-of-memory tensorflow2.0

I am using TensorFlow 2.0 and wrote a training program based on tf.GradientTape, with the training step wrapped in @tf.function. As training runs, memory keeps growing even though each of my documents is only about 550 words. My whole dataset is only about 30 MB, yet memory usage climbs to 290 GB. GPU memory usage also keeps increasing, and after finishing one epoch it reports that the GPU is out of memory. Can anyone help me figure this out?
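For context, the `train_step_signature` referenced below is not shown in the question. With @tf.function, omitting an input_signature makes TensorFlow retrace (and keep) a new concrete function for every new input shape, which by itself can grow memory over a long run; a fixed signature with None for the variable dimensions avoids that. A hypothetical reconstruction, where the dtypes and shapes are my assumptions, not the asker's actual code:

import tensorflow as tf

# Hypothetical: one TensorSpec per train_step argument, with None for
# the batch and sequence dimensions so no per-shape retracing occurs.
train_step_signature = [
    tf.TensorSpec(shape=(None,), dtype=tf.int64),         # group
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),    # inp
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),    # tar
    tf.TensorSpec(shape=(None, None), dtype=tf.float32),  # label (one-hot)
]

This is the train_step from the question: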

@tf.function(input_signature=train_step_signature)
def train_step(group, inp, tar, label):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]  # sess=tf.compat.v1.Session()
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        classfication, predictions, _ = transformer(inp, tar_inp,
                                                    True,
                                                    enc_padding_mask,
                                                    combined_mask,
                                                    dec_padding_mask)
        loss = loss_function(tar_real, predictions)
        loss2 = tf.nn.softmax_cross_entropy_with_logits(label, classfication)
        loss = loss + loss2

    # print(loss,loss2)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    class_loss(loss2)
    train_loss(loss)
    train_accuracy(tar_real, predictions)

    # gra = tape.gradient(loss2, transformer.trainable_variables)
    # optimizer.apply_gradients(zip(gra,transformer.trainable_variables))
    class_accuracy(tf.argmax(label, 1), classfication)
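One thing worth flagging in the snippet above: tf.GradientTape(persistent=True) keeps all forward-pass tensors alive until the tape object itself is released, yet only one tape.gradient call is actually made. A minimal sketch of the non-persistent pattern on a toy model (the model, optimizer, and loss here are stand-ins, not the question's Transformer):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam()

@tf.function
def toy_step(x, y):
    # Non-persistent tape: its resources are released right after gradient().
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = tf.reduce_mean(tf.square(pred - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

toy_step(tf.random.normal([2, 3]), tf.random.normal([2, 4]))
# If persistent=True is genuinely needed (several gradient() calls),
# release the tape explicitly once done:  del tape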

I use the following code to run the training loop:

tf.compat.v2.summary.trace_on(graph=True, profiler=True)
for epoch in range(EPOCHS):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    class_loss.reset_states()
    class_accuracy.reset_states()
    # inp -> portuguese, tar -> english
    for (batch, (group, inp, tar, label)) in enumerate(train_dataset):

        train_step(group, inp, tar, label)
        if batch % 50 == 0:
            print('Epoch {} Batch {} correct_Loss {:.4f} correct_Accuracy {:.4f} '
                  'class_accuracy {:.4f} class_loss {:.4f}'.format(
                      epoch + 1, batch, train_loss.result(), train_accuracy.result(),
                      class_accuracy.result(), class_loss.result()))

    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    print('Epoch {} correct_Loss {:.4f} correct_Accuracy {:.4f} '
          'class_accuracy {:.4f} class_loss {:.4f}'.format(
              epoch + 1, train_loss.result(), train_accuracy.result(),
              class_accuracy.result(), class_loss.result()))

    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
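Also note that tf.compat.v2.summary.trace_on(graph=True, profiler=True) keeps collecting graph and profiler data until it is exported or switched off, so leaving it enabled for the entire run grows memory on its own. A minimal sketch of the intended pairing, where the 'logs' directory name is made up for illustration:

import tensorflow as tf

writer = tf.summary.create_file_writer('logs')  # hypothetical log directory

tf.summary.trace_on(graph=True, profiler=True)
# ... run a single traced call here, e.g. one train_step(...), not the whole training ...
with writer.as_default():
    tf.summary.trace_export(name='train_step_trace', step=0, profiler_outdir='logs')
tf.summary.trace_off()  # stop collecting so the trace buffer can be released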

0 Answers:

No answers yet.