tensorflow 1.10 lstm + ctc memory leak

Time: 2018-09-20 06:55:36

Tags: python tensorflow memory-leaks lstm

Q: I have a memory-leak problem with TensorFlow. The program is as follows:

import sys
from os import getpid

import psutil
import tensorflow as tf

import cfg    # project-local config module
import model  # project-local model definition
import utils  # project-local helpers


def train(train_batch_iter, val_batch_iter, train_sum_num, val_sum_num):
    process = psutil.Process(getpid())
    global_step = tf.Variable(0, trainable=False)
    # Piecewise-constant learning-rate schedule
    # (steps_per_batch is defined in code omitted from this excerpt).
    boundaries = [10 * steps_per_batch, 50 * steps_per_batch]
    values_lr = [1e-2, 1e-3, 1e-4]
    lr = tf.train.piecewise_constant(global_step, boundaries, values_lr)

    logits, inputs, targets, seq_len, W, b = model.get_train_model()
    sparse_targets = utils.dense_to_sparse(targets)
    loss = tf.nn.ctc_loss(labels=sparse_targets, inputs=logits, sequence_length=seq_len)
    cost = tf.reduce_mean(loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss, global_step=global_step)
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(
        logits, seq_len, beam_width=cfg.beam_width, top_paths=cfg.top_paths,
        merge_repeated=False)
    output_dense = tf.sparse_to_dense(decoded[0].indices, decoded[0].dense_shape,
                                      decoded[0].values)

    # The mean edit distance between the decoded output and the labels
    # (both in sparse form) serves as the accuracy metric.
    edit_dis = tf.edit_distance(tf.cast(decoded[0], tf.int32), sparse_targets)
    dis = tf.reduce_mean(edit_dis)

    init = tf.global_variables_initializer()
    with tf.device("/gpu:%s" % cfg.gpu_id):
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False)) as sess:
            print('Train on %d samples, val on %d samples.' % (train_sum_num, val_sum_num))
            sess.run(init)

            early_stop_flag = False
            best_val_cost = sys.maxsize
            best_val_epoch = 0

            for epoch in range(cfg.epoch_num):
                if not early_stop_flag:
                    train_cost = 0
                    for batch in range(train_sum_num // cfg.batch_size):
                        # Measure this process's share of system memory
                        # before and after each training step.
                        before = process.memory_percent()
                        train_batch_inputs, train_batch_targets, train_batch_seq_len = next(train_batch_iter)
                        feed_dict = {inputs: train_batch_inputs,
                                     targets: train_batch_targets,
                                     seq_len: train_batch_seq_len}
                        batch_loss, batch_targets, batch_logits, batch_seq_len, batch_cost, step, _ = sess.run(
                            [loss, targets, logits, seq_len, cost, global_step, optimizer],
                            feed_dict=feed_dict)

                        train_cost += batch_cost * cfg.batch_size
                        del train_batch_inputs, train_batch_targets, train_batch_seq_len
                        del batch_loss, batch_targets, batch_logits, batch_seq_len
                        after = process.memory_percent()
                        print("Batch = %d, Memory CHANGE %.4f -> %.4f" % (batch, before, after))
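In case it helps diagnose this: in TF 1.x, creeping memory usage is often caused by ops being accidentally added to the graph inside the training loop. Freezing the graph right after initialization turns any such addition into an immediate error; a minimal sketch, assuming the graph is fully built by the time sess.run(init) returns:

    # Diagnostic sketch: finalize the graph before entering the training
    # loop. If anything inside the loop tries to add new ops (a classic
    # source of slow memory growth in TF 1.x), TensorFlow raises a
    # RuntimeError instead of silently growing the graph.
    sess.run(init)
    sess.graph.finalize()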

The memory-change output is as follows:

Epoch 0:
  Batch = 1, Memory CHANGE 3.5693 -> 3.5828
  Batch = 2, Memory CHANGE 3.5828 -> 3.5917
                     ... ...
  Batch = 3111, Memory CHANGE 3.8037 -> 3.8039
  Batch = 3112, Memory CHANGE 3.8039 -> 3.8039
  Batch = 3113, Memory CHANGE 3.8039 -> 3.8034

Epoch 1:
  Batch = 0, Memory CHANGE 3.8321 -> 3.8324
  Batch = 1, Memory CHANGE 3.8324 -> 3.8325
                     ... ...
  Batch = 903, Memory CHANGE 3.8560 -> 3.8558
  Batch = 904, Memory CHANGE 3.8558 -> 3.8563
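The per-batch deltas are tiny but consistently positive across epochs. One additional check would be to force a garbage collection before each measurement and log the absolute resident set size, to see whether the growth is on the Python heap at all; a minimal sketch (the helper name log_rss is hypothetical):

    import gc
    from os import getpid

    import psutil

    def log_rss(tag):
        # Force a GC pass first, so lingering Python garbage does not
        # masquerade as a leak, then report the absolute RSS in MiB.
        gc.collect()
        rss_mib = psutil.Process(getpid()).memory_info().rss / (1024 ** 2)
        print("%s: RSS = %.1f MiB" % (tag, rss_mib))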

Unfortunately, the approaches I have tried so far have not helped. I would be very grateful if you could solve this problem.

0 answers:

There are no answers yet.