The code below is a nested loop used to train a GRU in Python 2.7, but it is a RAM hog. feats_tensor and dec_padded_text are both very large objects, and loading them at the same time pushes me into out-of-memory errors. Any idea how to optimize this code for RAM usage?
for epoch in xrange(0, 13):
    print ("Starting New Epoch: %d" % epoch)
    np.random.shuffle(order)

    # Drop last epoch's tensors before rebuilding them in the new order
    del feats_tensor, dec_text_tensor
    if cuda:
        torch.cuda.empty_cache()

    # Rebuild the full shuffled tensors and move them onto the GPU
    feats_tensor = torch.tensor(feats[order], requires_grad=False)
    dec_text_tensor = torch.tensor(dec_padded_text[order], requires_grad=False)
    if cuda:
        feats_tensor = feats_tensor.cuda(device=device)
        dec_text_tensor = dec_text_tensor.cuda(device=device)

    for i in xrange(num_batches):
        s = i * BATCH_SIZE
        e = (i + 1) * BATCH_SIZE

        enc.zero_grad()
        dec.zero_grad()

        # Encode the batch, decode with teacher forcing, compute the loss
        hid_enc = enc.forward(feats_tensor[s:e]).unsqueeze(0)
        out_dec, hid_dec = dec.forward(dec_text_tensor[s:e, :-1], hid_enc)
        out_perm = out_dec.permute(0, 2, 1)
        loss = lossfunc(out_perm, dec_text_tensor[s:e, 1:])

        # Exponential moving average of the loss for logging
        if sm_loss is None:
            sm_loss = loss.data
        else:
            sm_loss = sm_loss * 0.95 + 0.05 * loss.data

        loss.backward()
        enc_optim.step()
        dec_optim.step()

        if i % 100 == 0:
            print ("Epoch: %.3f" % (i / float(num_batches) + epoch,), "Loss:", sm_loss)
            #print ("GEN:", untokenize(torch.argmax(out_dec, dim=2)[0, :], dec_idx_to_word))
            #print ("GT:", untokenize(dec_text_tensor[s, :], dec_idx_to_word))
            print ("--------------")

    save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word, dec_word_to_idx, epoch)
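For context, here is a minimal sketch of the direction I was considering, not working code from my project: keep feats and dec_padded_text as NumPy arrays in host RAM and only build (and transfer) a tensor for the current mini-batch, instead of materializing the full shuffled copies feats[order] and dec_padded_text[order] every epoch. All names (feats, dec_padded_text, order, enc, dec, lossfunc, enc_optim, dec_optim, BATCH_SIZE, num_batches, cuda, device) are assumed to match the snippet above.

import numpy as np
import torch

for epoch in xrange(0, 13):
    np.random.shuffle(order)
    for i in xrange(num_batches):
        s = i * BATCH_SIZE
        e = (i + 1) * BATCH_SIZE
        batch_idx = order[s:e]

        # Build tensors for this mini-batch only; the full dataset is never
        # duplicated by fancy indexing and never lives on the GPU all at once.
        feats_batch = torch.tensor(feats[batch_idx], requires_grad=False)
        text_batch = torch.tensor(dec_padded_text[batch_idx], requires_grad=False)
        if cuda:
            feats_batch = feats_batch.cuda(device=device)
            text_batch = text_batch.cuda(device=device)

        enc.zero_grad()
        dec.zero_grad()
        hid_enc = enc.forward(feats_batch).unsqueeze(0)
        out_dec, hid_dec = dec.forward(text_batch[:, :-1], hid_enc)
        loss = lossfunc(out_dec.permute(0, 2, 1), text_batch[:, 1:])
        loss.backward()
        enc_optim.step()
        dec_optim.step()

Is something along these lines the right way to go, or is there a better pattern for this?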