我正在尝试从本地训练好的GCMLE预测模型下载嵌入,这样我就可以使用我自己的自定义嵌入可视化,这些可视化在张量板中不可用。我想将这些嵌入提取到一个很大的numpy矩阵中,但是我遇到了一些麻烦。我可以成功下载所有文件(saved_model.pb
+ assets/*
+ variables/*
,我似乎可以使用以下代码恢复模型:
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)
成功返回:
INFO:tensorflow:Restoring parameters from Servo/variables/variables
然后我尝试提取这样的权重:
constant_values = {}
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)
constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
for constant_op in constant_ops:
constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
成功输出了很多,但与嵌入相关的唯一部分是:
u'embedding_layer/embeddings/Initializer/random_uniform/max': 0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/min': -0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/shape': array([vocab_size, word_embedding_size], dtype=int32)
并没有实际嵌入权重的迹象。如何修改上面的方法以获得实际的嵌入权重矩阵?
答案 0 :(得分:1)
它将取决于您如何导出模型,但在大多数情况下,嵌入是变量而不是常量。所以你需要这样的东西:
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)
trainable_coll = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
vars = {v.name:sess.run(v.value()) for v in trainable_coll}