TensorFlow RNN: setting the trainable flag

Asked: 2017-10-21 07:06:36

Tags: python tensorflow lstm rnn

I'm building a model in which I need to copy the network before training, so there is an "old" and a "new" network. Training is performed only on the new network; the old network is static. Large updates are prevented by clipping the magnitude of the training updates based on the difference between the two networks (see https://arxiv.org/abs/1707.06347).

With tf.layers it's easy to set the trainable flag, like so:

def _build_cnet(self, name, trainable):
    w_reg = tf.contrib.layers.l2_regularizer(L2_REG)

    with tf.variable_scope(name):
        l1 = tf.layers.dense(self.state, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l1")
        l2 = tf.layers.dense(l1, 400, tf.nn.relu, trainable=trainable, kernel_regularizer=w_reg, name="vf_l2")
        vf = tf.layers.dense(l2, 1, trainable=trainable, kernel_regularizer=w_reg, name="vf_out")
    params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
    return vf, params
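For context on what that flag does: in TF 1.x a variable is always added to the GLOBAL_VARIABLES collection, but only joins TRAINABLE_VARIABLES when trainable=True, and Optimizer.minimize() defaults its var_list to TRAINABLE_VARIABLES. The mechanics can be sketched without TensorFlow (the variable names below are just illustrative):

```python
# Sketch (plain Python, not TensorFlow): how the trainable= flag affects
# TF 1.x variable collections. A variable always joins GLOBAL_VARIABLES;
# it joins TRAINABLE_VARIABLES only when trainable=True, which is the
# pool Optimizer.minimize() draws from by default.
GLOBAL_VARIABLES = []
TRAINABLE_VARIABLES = []

def make_variable(name, trainable=True):
    """Mimic tf.get_variable's collection registration."""
    GLOBAL_VARIABLES.append(name)
    if trainable:
        TRAINABLE_VARIABLES.append(name)
    return name

make_variable("new/vf_l1/kernel", trainable=True)   # will be optimized
make_variable("old/vf_l1/kernel", trainable=False)  # frozen copy

print(TRAINABLE_VARIABLES)  # ['new/vf_l1/kernel']
```

This is why the question code collects params from GLOBAL_VARIABLES: that collection holds the old network's (non-trainable) weights too, which is what you want when copying parameters between the two networks.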

I'm trying to add an LSTM layer at the end of the network, like so:

def _build_cnet(self, name, trainable):
    w_reg = tf.contrib.layers.l2_regularizer(L2_REG)

    with tf.variable_scope(name):
        c_lstm = tf.contrib.rnn.BasicLSTMCell(CELL_SIZE)
        self.c_init_state = c_lstm.zero_state(batch_size=1, dtype=tf.float32)

        l1 = tf.layers.dense(self.state, 400, tf.nn.relu, trainable=trainable,
                             kernel_regularizer=w_reg, name="vf_l1")
        l2 = tf.layers.dense(l1, 400, tf.nn.relu, trainable=trainable, kernel_regularizer=w_reg, name="vf_l2")

        # LSTM layer
        c_outputs, self.c_final_state = tf.nn.dynamic_rnn(cell=c_lstm, inputs=tf.expand_dims(l2, axis=0),
                                                          initial_state=self.c_init_state)
        c_cell_out = tf.reshape(c_outputs, [-1, CELL_SIZE], name='flatten_lstm_outputs')

        vf = tf.layers.dense(c_cell_out, 1, trainable=trainable, kernel_regularizer=w_reg, name="vf_out")
    params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
    return vf, params

Is there an easy way to add a trainable flag to tf.contrib.rnn.BasicLSTMCell or tf.nn.dynamic_rnn?

It seems that RNNCell has a trainable flag, but BasicLSTMCell does not?
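One possible workaround (my own suggestion, not something stated in the question): since BasicLSTMCell creates its variables lazily inside the enclosing variable scope, you can sidestep a per-cell trainable flag entirely by passing an explicit var_list to the optimizer, e.g. optimizer.minimize(loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="new")), so the old network's LSTM weights never receive updates. The scope argument is just a regex matched against the start of each variable name, which can be sketched without TensorFlow (the variable names are hypothetical):

```python
import re

def get_collection(variables, scope=None):
    """Sketch of tf.get_collection's scope filtering: scope is treated
    as a regex and matched against the start of each variable's name."""
    if scope is None:
        return list(variables)
    pattern = re.compile(scope)
    return [v for v in variables if pattern.match(v)]

all_vars = ["old/lstm/kernel", "old/vf_out/kernel",
            "new/lstm/kernel", "new/vf_out/kernel"]

# Only the "new" network's variables would be handed to the optimizer:
print(get_collection(all_vars, scope="new"))
# ['new/lstm/kernel', 'new/vf_out/kernel']
```

With this approach the old network's variables can stay technically trainable; they are simply never included in the optimizer's update.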

0 Answers:

No answers yet.