从GCMLE保存的模型中提取嵌入

时间:2018-03-15 01:24:00

标签: python tensorflow google-cloud-ml

我正在尝试从本地训练好的GCMLE预测模型下载嵌入,这样我就可以使用我自己的自定义嵌入可视化,这些可视化在张量板中不可用。我想将这些嵌入提取到一个很大的numpy矩阵中,但是我遇到了一些麻烦。我可以成功下载所有文件(saved_model.pb + assets/* + variables/*,我似乎可以使用以下代码恢复模型:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)

成功返回:

INFO:tensorflow:Restoring parameters from Servo/variables/variables

然后我尝试提取这样的权重:

constant_values = {}

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)

    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

成功输出了很多,但与嵌入相关的唯一部分是:

u'embedding_layer/embeddings/Initializer/random_uniform/max': 0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/min': -0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/shape': array([vocab_size, word_embedding_size], dtype=int32)

并没有实际嵌入权重的迹象。如何修改上面的方法以获得实际的嵌入权重矩阵?

1 个答案:

答案 0 :(得分:1)

它将取决于您如何导出模型,但在大多数情况下,嵌入是变量而不是常量。所以你需要这样的东西:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)

    trainable_coll = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    vars = {v.name:sess.run(v.value()) for v in trainable_coll}