我使用 TensorFlow 2.0 编写了一个基于 GradientTape（梯度带）的训练程序，并用 @tf.function 装饰训练函数。但随着训练进行，尽管我的语料只有约 550 个词、数据总大小仅 30M，内存占用却一直增长，最高达到 290G；GPU 显存占用也在不断增加，跑完一个 epoch 后就报 GPU 内存不足。有人能帮我解决这个问题吗？
@tf.function(input_signature=train_step_signature)
def train_step(group, inp, tar, label):
    """Run one optimizer step on a joint seq2seq + classification objective.

    Args:
        group: extra batch component from the dataset (unused here — TODO confirm
            whether it should feed the model or can be dropped from the signature).
        inp: encoder input token ids.
        tar: decoder target token ids; teacher forcing uses tar[:, :-1] as input
            and tar[:, 1:] as the prediction target.
        label: one-hot classification labels for the auxiliary head.

    Side effects: applies gradients to `transformer` via `optimizer` and updates
    the streaming metrics `class_loss`, `train_loss`, `train_accuracy`,
    `class_accuracy`.
    """
    tar_inp = tar[:, :-1]   # decoder input: drop the last token
    tar_real = tar[:, 1:]   # prediction target: shift left by one
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    # FIX: the original used GradientTape(persistent=True) but called
    # tape.gradient() only once and never deleted the tape. A persistent tape
    # keeps every intermediate activation alive after the step, so memory
    # grows with each batch — the likely cause of the reported OOM. A
    # non-persistent tape releases its resources as soon as gradient() runs.
    with tf.GradientTape() as tape:
        classification, predictions, _ = transformer(inp, tar_inp,
                                                     True,
                                                     enc_padding_mask,
                                                     combined_mask,
                                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)
        loss2 = tf.nn.softmax_cross_entropy_with_logits(label, classification)
        # Joint objective: sequence loss plus classification loss.
        loss = loss + loss2

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Update streaming metrics (reset once per epoch by the training loop).
    class_loss(loss2)
    train_loss(loss)
    train_accuracy(tar_real, predictions)
    class_accuracy(tf.argmax(label, 1), classification)
我使用以下代码调用上面的 tf.function 进行训练：
# FIX: trace_on(graph=True, profiler=True) keeps the graph tracer AND the
# profiler recording for the entire run; their buffers grow without bound and
# there is no matching summary.trace_export()/trace_off() anywhere, so the
# collected data is never flushed — a second source of the memory blow-up.
# If you need a graph trace, trace exactly one step inside a summary-writer
# context, export it, then call tf.summary.trace_off(). Disabled here.
# tf.compat.v2.summary.trace_on(graph=True, profiler=True)
for epoch in range(EPOCHS):
    start = time.time()

    # Streaming metrics accumulate across calls; reset once per epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()
    class_loss.reset_states()
    class_accuracy.reset_states()

    # inp -> portuguese, tar -> english
    for batch, (group, inp, tar, label) in enumerate(train_dataset):
        train_step(group, inp, tar, label)
        if batch % 50 == 0:
            print(
                'Epoch {} Batch {} correct_Loss {:.4f} Correct_Accuracy {:.4f} class_accurcay{:.4f} class_loss{:.4f}'.format(
                    epoch + 1, batch, train_loss.result(), train_accuracy.result(),
                    class_accuracy.result(), class_loss.result()))

    # Checkpoint every 5 epochs.
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    # End-of-epoch summary.
    print('Epoch {} correct_Loss {:.4f} correct_Accuracy {:.4f} class_accurcay{:.4f} class_loss{:.4f}'.format(
        epoch + 1, train_loss.result(), train_accuracy.result(),
        class_accuracy.result(), class_loss.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))