Tensorflow cond不会在假分支上停止渐变

时间:2017-07-29 15:35:49

标签: tensorflow

我正在构建一个RNN模型,其中init_state可能来自两种情况之一。 1)从前一个时间步输出状态通过feed_dict输入的静态init_state。 2)变量的一些功能,我称之为得分。

init_state = cell.zero_state(batch,tf.float32)
with tf.name_scope('hidden1'):
     weights_h1 = tf.Variable(
                        tf.truncated_normal([T, cells_dim],
                        stddev=1.0 / np.sqrt(T)),
                        name='weights')
     biases_h1 = tf.Variable(tf.zeros([cells_dim]),
                        name='biases')
     hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

init_state2 = tf.cond(is_start, lambda: hidden1, lambda: init_state)
然后将 init_state2用作static_rnn的输入,最终用于计算loss和train_op。当is_start为False时,我希望train_op对weights_h1没有影响。但是,每次更新后重量都会发生变化。非常感谢任何帮助。

1 个答案:

答案 0 :(得分:1)

这应该有效:

def return_init_state():
    init_state = cell.zero_state(batch,tf.float32)
    return init_state

def return_hidden_1():
    with tf.name_scope('hidden1'):
        weights_h1 = tf.Variable(
                            tf.truncated_normal([T, cells_dim],
                            stddev=1.0 / np.sqrt(T)),
                            name='weights')
        biases_h1 = tf.Variable(tf.zeros([cells_dim]),
                            name='biases')
        hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

        return hidden1

init_state2 = tf.cond(is_start, lambda: return_hidden_1, lambda: return_init_state)

注意如何在tf.cond的上下文中调用这些方法。因此,无论创建什么op,都将在tf.cond的上下文中。否则,根据您的情况,操作将以两种方式运行。