I used the tutorial code to measure how much memory a frozen model (.pb file) saves compared to checkpoint files. The theory (and, AFAIK, the implementation) is that a checkpoint keeps the many variables used during training, while a frozen model converts them (the weights and biases) into constants, so memory consumption should be lower. However, when I compare memory consumption, the difference is only about 50 MB (660 MB vs 610 MB). So I wonder: what is the point of freezing a model, if its in-memory size when serving is not much different from recreating the model from checkpoint files? Below is the code around the point where I measured. I am not sure it gives the whole picture, but it is entirely based on https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/translate.py (the decode method).
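For context, the freezing itself was done the usual way, along the lines of graph_util.convert_variables_to_constants. A rough sketch (here sess is a session with the checkpoint restored, and output_node_names stands in for the actual list of output node names):

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Replace every Variable reachable from the output nodes with a Const
# holding its current value, then serialize the resulting GraphDef.
output_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), output_node_names)
with tf.gfile.GFile('save/NER_SAVE/frozen_model.pb', 'wb') as f:
    f.write(output_graph_def.SerializeToString())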
Note that the code below is my version: it loads the frozen graph, builds the input data, and calls run directly. I also tried to pinpoint where memory blows up, and the output makes it clear that it happens right after session.run. Any pointers on what else I could do to reduce memory consumption? My final setup will involve at least a few models of this kind running together. Many thanks!
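For reference, load_graph (called below but not shown) is essentially the standard TF 1.x frozen-graph loader. A minimal sketch, where name='prefix' is the reason every tensor below is looked up under 'prefix/...':

import tensorflow as tf

def load_graph(frozen_graph_filename):
    # Parse the serialized GraphDef from the .pb file.
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import it into a fresh Graph under the 'prefix' name scope.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
    return graph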
def decode_NER():
    # Load the frozen graph; all trained variables are baked in as constants.
    graph = load_graph('save/NER_SAVE/frozen_model.pb')
    for op in graph.get_operations():
        print(op.name)

    # Placeholders for the five encoder inputs and ten decoder inputs.
    encoder_placeholders = [graph.get_tensor_by_name('prefix/encoder%d:0' % i)
                            for i in range(5)]
    decoder_placeholders = [graph.get_tensor_by_name('prefix/decoder%d:0' % i)
                            for i in range(10)]

    # One output logit tensor per decoder step (AttnOutputProjection,
    # AttnOutputProjection_1, ..., AttnOutputProjection_9).
    attn_output = ('prefix/model_with_buckets/embedding_attention_seq2seq/'
                   'embedding_attention_decoder/attention_decoder/'
                   'AttnOutputProjection%s/BiasAdd:0')
    y = [graph.get_tensor_by_name(attn_output % ('' if i == 0 else '_%d' % i))
         for i in range(10)]

    with tf.Session(graph=graph) as sess:
        ...........
        while sentence:
            killer = sentence.split(' ')
            # The model consumes POS tags rather than the words themselves.
            from nltk import pos_tag, word_tokenize
            vik = pos_tag(word_tokenize(sentence))
            poss = ''
            for wd, tag in vik:
                poss += tag + ' '
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(poss), en_vocab)

            # Memory snapshot before running the graph.
            print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
            print(subprocess.run(['free', '-mh'], stdout=subprocess.PIPE))
            print(sys.getsizeof(sess))
            # Which bucket does it belong to?
            bucket_id = len(_buckets) - 1
            for i, bucket in enumerate(_buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                logging.warning("Sentence truncated: %s", sentence)
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = loc_get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            print(encoder_inputs)

            # Feed each encoder/decoder time step into its placeholder.
            feed_dict = dict(zip(encoder_placeholders, encoder_inputs))
            feed_dict.update(zip(decoder_placeholders, decoder_inputs))
            output_logits = sess.run(y, feed_dict=feed_dict)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            # Memory snapshot after session.run; this is where usage jumps.
            print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
            print(subprocess.run(['free', '-mh'], stdout=subprocess.PIPE))
            print(sys.getsizeof(sess))
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            print(outputs)
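A caveat on the measurements above: sys.getsizeof(sess) only reports the size of the Python wrapper object, not the memory held by the TensorFlow C++ runtime, and free -mh is machine-wide. A more direct per-process number is the resident set size; a minimal sketch using only the standard library (Linux only, where ru_maxrss is reported in kilobytes):

import resource

def peak_rss_mb():
    # Peak resident set size of the current process, converted to MB.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0

print('peak RSS: %.1f MB' % peak_rss_mb())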