冻结OpenAI GPT2模型的张量流图

时间:2019-06-16 15:36:36

标签: tensorflow

试图冻结GPT2精细调整的模型,但无法确定输出节点名称是什么。使用this代码作为参考,我将其组合在一起:-

import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder

seed=None
length=40
temperature=1
top_k=0

hparams = model.default_hparams()
with open('models/345M/hparams.json') as f:
  hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
  context = tf.placeholder(tf.int32, [1, None])
  np.random.seed(seed)
  tf.set_random_seed(seed)
  output = sample.sample_sequence(
      hparams=hparams, length=length,
      context=context,
      batch_size=1,
      temperature=temperature, top_k=top_k
  )

  saver = tf.train.Saver()
  ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
  saver.restore(sess, ckpt)

  print([n.name for n in tf.get_default_graph().as_graph_def().node])

  # Freeze the graph
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,[output.name])

  # Save the frozen graph
  with open('output_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())

但我知道

  

AssertionError:sample_sequence / while / Exit_3:0不在图中

那么我应该将什么作为参数3输出节点名称放在Frozen_graph中?

0 个答案:

没有答案
相关问题