减少 TensorFlow 冻结模型的内存消耗

时间:2017-10-10 05:35:04

标签: python memory tensorflow tensorflow-serving

我使用教程代码来比较从检查点文件加载模型与使用冻结模型(.pb文件)各自占用的内存。理论上(据我所知,实现上也是如此),检查点文件保存了模型训练期间的大量变量,而冻结模型会把它们转换为常量(权重和偏差),因此内存消耗应该更低。但是,当我比较两者的内存消耗时,差异只有大约50 MB(660 MB vs 610 MB)。我不明白,冻结模型的用途到底是什么?因为无论如何,在为模型提供服务时,它在内存中的大小与从检查点文件重新创建模型相比并没有明显差别。下面贴出我进行测量的部分代码。不确定这些代码能否反映全貌,但它完全参考了 https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/translate.py 中的解码(decode)方法。

请注意,下面的代码是我加载冻结图的版本:创建输入数据后直接调用 run。我还试图确定内存是在哪一步暴涨的,从输出可以明显看出是在 session.run 之后。对于还能做些什么来减少内存消耗,有什么建议吗?我的最终系统至少需要同时运行几个这类模型。万分感谢!

def decode_NER():
    """Decode sentences with the frozen NER seq2seq graph and report memory.

    Loads ``save/NER_SAVE/frozen_model.pb``, prints the name of every op in
    the graph, looks up the five ``prefix/encoderN:0`` and ten
    ``prefix/decoderN:0`` input tensors together with the ten
    ``AttnOutputProjection*/BiasAdd:0`` output tensors, and then, while
    ``sentence`` is truthy:

    1. POS-tags the sentence with NLTK and turns the tag sequence into token
       ids via ``data_utils.sentence_to_token_ids``.
    2. Picks the first bucket whose encoder size fits the token ids (falling
       back to the last bucket, with a warning, when none fits).
    3. Builds a one-element batch with ``loc_get_batch`` and runs the graph.
    4. Prints the greedy (argmax) decoding of the output logits.

    The result of ``free -mh`` and ``sys.getsizeof(sess)`` is printed right
    before and right after ``sess.run`` so the two readings can be compared.

    Relies on names defined outside this function: ``load_graph``, ``tf``,
    ``data_utils``, ``en_vocab``, ``_buckets``, ``loc_get_batch``, ``np``,
    ``subprocess``, ``sys``, ``logging`` and ``sentence``.
    """
    # Imported once here instead of on every loop iteration.
    from nltk import pos_tag, word_tokenize

    # The frozen graph exposes a fixed number of encoder/decoder inputs.
    num_encoder_inputs = 5
    num_decoder_inputs = 10
    output_scope = (
        'prefix/model_with_buckets/embedding_attention_seq2seq/'
        'embedding_attention_decoder/attention_decoder/'
        'AttnOutputProjection'
    )

    graph = load_graph('save/NER_SAVE/frozen_model.pb')
    for op in graph.get_operations():
        print(op.name)

    encoder_placeholders = [
        graph.get_tensor_by_name('prefix/encoder%d:0' % step)
        for step in range(num_encoder_inputs)
    ]
    decoder_placeholders = [
        graph.get_tensor_by_name('prefix/decoder%d:0' % step)
        for step in range(num_decoder_inputs)
    ]
    # The first projection op has no numeric suffix; later ones are _1.._9.
    y = [
        graph.get_tensor_by_name(
            '%s%s/BiasAdd:0' % (output_scope, '_%d' % step if step else ''))
        for step in range(num_decoder_inputs)
    ]

    with tf.Session(graph=graph) as sess:
        # ........... (code elided in the original post; it is expected to
        # define `sentence`, and `en_vocab` unless that is a module global)
        while sentence:
            tagged = pos_tag(word_tokenize(sentence))
            # Each tag is followed by a single space, including the last one.
            poss = ''.join(tag + ' ' for _, tag in tagged)
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(poss), en_vocab)
            _print_memory_usage(sess)

            # Which bucket does it belong to?
            bucket_id = len(_buckets) - 1
            for i, bucket in enumerate(_buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                # No bucket is large enough: keep the last bucket and warn.
                logging.warning("Sentence truncated: %s", sentence)

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = loc_get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)

            print(encoder_inputs)
            # Index explicitly so a too-short batch still raises IndexError.
            feed_dict = {
                placeholder: encoder_inputs[step]
                for step, placeholder in enumerate(encoder_placeholders)
            }
            feed_dict.update(
                (placeholder, decoder_inputs[step])
                for step, placeholder in enumerate(decoder_placeholders)
            )
            output_logits = sess.run(y, feed_dict=feed_dict)

            # This is a greedy decoder - outputs are just argmaxes of
            # output_logits.
            _print_memory_usage(sess)
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            print(outputs)
            # NOTE(review): nothing visible here assigns a new `sentence`, so
            # as shown the loop repeats on the same input; presumably the next
            # sentence is read in code elided from the post - confirm.


def _print_memory_usage(sess):
    """Print a separator, the output of ``free -mh`` and the session size."""
    print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
    print(subprocess.run(['free', '-mh'], stdout=subprocess.PIPE))
    # sys.getsizeof is shallow: it reports only the Python wrapper object,
    # not memory held by objects it references.
    print(sys.getsizeof(sess))

0 个答案:

没有答案