如何在Tensorflow中重用权重?

时间:2017-04-28 23:46:01

标签: python tensorflow deep-learning

我会多次调用下面的tensorflow代码。是否重复使用权重或每次创建新图表?

    def lstm(encoder_cell, encoder_inputs_embedded, encoder_inputs_length):
        with tf.variable_scope('lstm') as scope_bilstm:
            ((encoder_fw_outputs,
              encoder_bw_outputs),
             (encoder_fw_state,
              encoder_bw_state)) = (
                tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                                cell_bw=encoder_cell,
                                                inputs=encoder_inputs_embedded,
                                                sequence_length=encoder_inputs_length,
                                                time_major=False,
                                                dtype=tf.float32)
                )

        encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

        return encoder_outputs

由于我知道它可能无法重复使用,因此我在reuse=True中添加了额外的tf.variable_scope(),尝试了以下代码。

    def lstm(encoder_cell, encoder_inputs_embedded, encoder_inputs_length):
        with tf.variable_scope('lstm', reuse=True) as scope_bilstm:
            ((encoder_fw_outputs,
              encoder_bw_outputs),
             (encoder_fw_state,
              encoder_bw_state)) = (
                tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                                cell_bw=encoder_cell,
                                                inputs=encoder_inputs_embedded,
                                                sequence_length=encoder_inputs_length,
                                                time_major=False,
                                                dtype=tf.float32)
                )

        encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)

        return encoder_outputs

但是我收到了以下错误,

  

ValueError:变量inner / bidirectional_rnn / fw / lstm_cell / weights确实如此   不存在,或者不是用tf.get_variable()创建的。你的意思是   在VarScope中设置reuse = None?

如何解决此错误?我真的很感激!

scope_bilstm.reuse_variables()怎么样?我不知道在我的程序中插入该行的位置。

1 个答案:

答案 0 :(得分:0)

可能有用,请查看load_with_skip函数

中的this