我正在尝试冻结一个经过微调的 GPT-2 模型,但无法确定输出节点的名称是什么。我参考这段代码,拼凑出了如下脚本:
import fire
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder
# Script: build the GPT-2 sampling graph, restore the latest checkpoint found
# under models/345M, and write a frozen GraphDef (variables converted to
# constants) to output_graph.pb.

# Sampling settings forwarded to sample.sample_sequence.
seed = None
length = 40
temperature = 1
top_k = 0

# Start from the model's default hyper-parameters and override them with the
# values stored next to the checkpoint.
hparams = model.default_hparams()
with open('models/345M/hparams.json') as f:
    hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
    # Token-id input: batch of 1, variable sequence length.
    context = tf.placeholder(tf.int32, [1, None])
    np.random.seed(seed)
    tf.set_random_seed(seed)
    output = sample.sample_sequence(
        hparams=hparams, length=length,
        context=context,
        batch_size=1,
        temperature=temperature, top_k=top_k
    )

    # Restore the trained weights into the freshly built graph.
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
    saver.restore(sess, ckpt)

    # Dump every node name so the available graph nodes can be inspected.
    print([n.name for n in tf.get_default_graph().as_graph_def().node])

    # Freeze the graph.
    # convert_variables_to_constants expects *node* (operation) names, not
    # tensor names. `output.name` is a tensor name carrying an output-index
    # suffix (e.g. "sample_sequence/while/Exit_3:0"), which is not a node in
    # the graph; `output.op.name` is the name of the op that produces it.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output.op.name]
    )

    # Save the frozen graph.
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
但我得到了以下错误:
AssertionError: sample_sequence/while/Exit_3:0 is not in graph
那么,在冻结图时,我应该把什么作为第三个参数(输出节点名称)传给 convert_variables_to_constants?