Statefulness in eager mode vs. non-eager mode

Time: 2019-08-25 22:32:13

Tags: python tensorflow

I have an RNN that works in eager mode but not in non-eager (graph) mode, and I would like to understand why.

The RNN should take a sequence of numbers and output its cumulative sum. Here is the code:

import tensorflow as tf
import numpy as np
tf.enable_eager_execution()

init_state_getter = lambda shape, dtype: np.ones(shape).astype(np.float32) # was ones

rnn = tf.keras.layers.SimpleRNN(
    1, # 1-dimensional state
    stateful=True,   
    return_sequences=True,
    activation=None, 
    use_bias=False,
    kernel_initializer = init_state_getter,
    recurrent_initializer = init_state_getter,
    return_state=False
)

x = np.ones([1,1,1]).astype(np.float32)

for _ in range(5):
    print(rnn(x))

This works, and we get the output

tf.Tensor([[[1.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[2.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[3.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[4.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[5.]]], shape=(1, 1, 1), dtype=float32)
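
For reference, the running sum above is what the recurrence h_t = x_t * W + h_(t-1) * U gives when W = U = 1, there is no bias and no activation, and the state starts at zero. Here is a minimal plain-Python sketch of that arithmetic. `TinyStatefulRNN` is a hypothetical stand-in written for illustration, not part of TensorFlow, and it only assumes the layer configuration shown in the question.

```python
class TinyStatefulRNN:
    """Hypothetical 1-unit linear RNN that keeps its state between calls."""

    def __init__(self, kernel=1.0, recurrent_kernel=1.0):
        self.kernel = kernel                      # input weight W
        self.recurrent_kernel = recurrent_kernel  # recurrent weight U
        self.state = 0.0                          # assumed zero initial state

    def __call__(self, xs):
        outputs = []
        for x in xs:
            # h_t = x_t * W + h_(t-1) * U, with no bias and no activation
            self.state = x * self.kernel + self.state * self.recurrent_kernel
            outputs.append(self.state)
        return outputs


rnn_sketch = TinyStatefulRNN()
# Five separate calls with a one-step sequence, as in the loop above.
stateful_outputs = [rnn_sketch([1.0]) for _ in range(5)]
print(stateful_outputs)  # [[1.0], [2.0], [3.0], [4.0], [5.0]]
```

Because the object keeps `state` between calls, each call continues from where the previous one stopped, which is the behaviour expected from stateful=True.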

Now let's modify it for non-eager mode:

import tensorflow as tf
import numpy as np

init_state_getter = lambda shape, dtype: np.ones(shape).astype(np.float32)

with tf.Session() as sess:

    x = tf.placeholder(tf.float32, shape=(1,1,1))

    rnn = tf.keras.layers.SimpleRNN(
        1, 
        stateful=True,   
        return_sequences=True,
        activation=None, 
        use_bias=False,
        kernel_initializer = init_state_getter,
        recurrent_initializer = init_state_getter,
        return_state=False
    )
    output = rnn(x)
    init = tf.global_variables_initializer()

    sess.run(init)


    for _ in range(5):
        print(sess.run(output, feed_dict={x:np.ones([1,1,1]).astype(np.float32)}))

This time, even though I never re-initialize, the RNN seems to lose its state:

[[[1.]]]
[[[1.]]]
[[[1.]]]
[[[1.]]]
[[[1.]]]

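Now, I suppose I could explicitly feed the old state to the RNN as part of the feed_dict, but (1) I am surprised that I need to, given that we have set stateful=True, and (2) I don't know how I would do this if I were using an Estimator.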

1 Answer:

Answer 0: (score: 0)

The problem here is that you are not passing an initial state when calling the RNN. The code can be modified as follows:

import tensorflow as tf
import numpy as np

with tf.Session() as sess:

    x = tf.placeholder(tf.float32, shape=(1,1,1))

    rnn = tf.keras.layers.SimpleRNN(
        1, 
        stateful=True,   
        return_sequences=True,
        activation=None, 
        use_bias=False,
        kernel_initializer = 'ones',
        recurrent_initializer = 'ones',
        return_state=True 
    )

    current_state = np.zeros((1,1))
    state_placeholder = tf.placeholder(tf.float32, shape=[1, 1])
    output, state = rnn(x, initial_state=state_placeholder) # You need to explicitly pass the state in the next step 
    init = tf.global_variables_initializer()
    sess.run(init)

    for _ in range(5):
        op_val, state_val = sess.run([output, state], feed_dict={x:np.ones([1,1,1]).astype(np.float32),
                                                                 state_placeholder: current_state.astype(np.float32)})
        current_state = state_val
        print(op_val)
        # print(sess.run(output, feed_dict={x:np.ones([1,1,1]).astype(np.float32)}))
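
To see why feeding the state back in matters, here is a minimal plain-Python sketch of the same loop, not TensorFlow code. `rnn_step` is a hypothetical helper that assumes a single unit with both weights equal to 1, no bias, and no activation, matching the layer configuration above.

```python
def rnn_step(x, state, kernel=1.0, recurrent_kernel=1.0):
    """One step of a hypothetical 1-unit linear RNN; returns (output, new_state)."""
    new_state = x * kernel + state * recurrent_kernel
    # For a simple RNN the output of a step is the new state itself.
    return new_state, new_state


# State threaded through the loop, like current_state / state_placeholder above.
current_state = 0.0
threaded = []
for _ in range(5):
    out, current_state = rnn_step(1.0, current_state)
    threaded.append(out)

# State never fed back: every call starts again from zero.
dropped = [rnn_step(1.0, 0.0)[0] for _ in range(5)]

print(threaded)  # [1.0, 2.0, 3.0, 4.0, 5.0]
print(dropped)   # [1.0, 1.0, 1.0, 1.0, 1.0]
```

The `dropped` list reproduces the symptom from the question, and `threaded` reproduces the cumulative sum once the previous state is passed in on every call.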