重用一部分张量流训练图

时间:2016-12-06 15:35:07

标签: serialization tensorflow

所以,我训练了一个几层的张量流模型,或多或少像这样:

with tf.variable_scope('model1') as scope:

    inputs = tf.placeholder(tf.int32, [None, num_time_steps])
    embeddings = tf.get_variable('embeddings', (vocab_size, embedding_size))
    lstm = tf.nn.rnn_cell.LSTMCell(lstm_units)

    embedded = tf.nn.embedding_lookup(embeddings, inputs)
    _, state = tf.nn.dynamic_rnn(lstm, embedded, dtype=tf.float32, scope=scope)

    # more stuff on the state

现在,我想在另一个模型中重用嵌入矩阵和lstm权重,除了这两个组件之外,它与这个非常不同。

据我所知,如果我用tf.Saver对象加载它们,它会寻找 具有完全相同名称的变量,但我在两个图中使用了不同的variable_scope

this answer中,建议创建图表,其中LSTM被训练为另一个的超集,但我不认为在我的情况下这是可能的,因为在两种型号。无论如何,如果他们做独立的事情,我认为让一个图依赖于另一个图是个好主意。

我考虑过在序列化图中更改LSTM权重和嵌入的变量范围。我的意思是,在最初阅读model1/Weights:0或其他内容的地方,它将是another_scope/Weights:0。是否可行?

当然,如果有更好的解决方案,也欢迎。

1 个答案:

答案 0 :(得分:1)

我发现可以使用字典将序列化文件中的变量名称(没有尾随:0)映射到我想要在图中恢复的变量对象来初始化Saver。例如:

varmap = {'model1/some_scope/weights': variable_in_model2,
          'model1/another_scope/weights': another_variable_in_model2}

saver = tf.train.Saver(varmap)
saver.restore(sess, path_to_saved_file)