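I have an RNN that works in eager mode but not in non-eager (graph) mode, and I would like to understand why.
The RNN should take a sequence of numbers and output their cumulative sum. Here is the code: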
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
init_state_getter = lambda shape, dtype: np.ones(shape).astype(np.float32)  # was ones
rnn = tf.keras.layers.SimpleRNN(
    1,  # 1-dimensional state
    stateful=True,
    return_sequences=True,
    activation=None,
    use_bias=False,
    kernel_initializer=init_state_getter,
    recurrent_initializer=init_state_getter,
    return_state=False
)
x = np.ones([1, 1, 1]).astype(np.float32)
for _ in range(5):
    print(rnn(x))
This works, and we get the output:
tf.Tensor([[[1.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[2.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[3.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[4.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[5.]]], shape=(1, 1, 1), dtype=float32)
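Why this produces a running sum: Keras' SimpleRNN computes h_t = activation(x_t · kernel + h_{t-1} · recurrent_kernel + bias). With activation=None, use_bias=False, all-ones weights and an input of 1, each step simply adds 1 to the previous state, and stateful=True makes the final state of one call the initial state of the next. The NumPy sketch below models only that arithmetic; `StatefulLinearRNN` is a hypothetical illustration, not a Keras API, and it assumes (as Keras does) that the state starts at zero.

```python
import numpy as np


class StatefulLinearRNN:
    """Hypothetical NumPy model of a SimpleRNN with activation=None and
    use_bias=False: h_t = x_t @ kernel + h_{t-1} @ recurrent_kernel.
    Like stateful=True, the final state of one call is kept as the
    initial state of the next call."""

    def __init__(self, units=1, input_dim=1):
        # All-ones weights, matching the question's initializer.
        self.kernel = np.ones((input_dim, units), dtype=np.float32)
        self.recurrent_kernel = np.ones((units, units), dtype=np.float32)
        self.state = None  # created lazily on the first call

    def __call__(self, x):
        batch, timesteps, _ = x.shape
        if self.state is None:
            # Stateful states start at zero.
            self.state = np.zeros((batch, self.kernel.shape[1]), dtype=np.float32)
        h = self.state
        outputs = []
        for t in range(timesteps):
            h = x[:, t, :] @ self.kernel + h @ self.recurrent_kernel
            outputs.append(h)
        self.state = h  # persist the state across calls
        return np.stack(outputs, axis=1)  # like return_sequences=True


rnn = StatefulLinearRNN()
x = np.ones([1, 1, 1], dtype=np.float32)
print([rnn(x).item() for _ in range(5)])  # [1.0, 2.0, 3.0, 4.0, 5.0]
```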
Now let's modify it for non-eager (graph) mode:
import tensorflow as tf
import numpy as np
init_state_getter = lambda shape, dtype: np.ones(shape).astype(np.float32)
with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=(1, 1, 1))
    rnn = tf.keras.layers.SimpleRNN(
        1,
        stateful=True,
        return_sequences=True,
        activation=None,
        use_bias=False,
        kernel_initializer=init_state_getter,
        recurrent_initializer=init_state_getter,
        return_state=False
    )
    output = rnn(x)
    init = tf.global_variables_initializer()
    sess.run(init)
    for _ in range(5):
        print(sess.run(output, feed_dict={x: np.ones([1, 1, 1]).astype(np.float32)}))
This time the RNN seems to lose its state, even though I never reinitialize it:
[[[1.]]]
[[[1.]]]
[[[1.]]]
[[[1.]]]
[[[1.]]]
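I suppose I could explicitly feed the old state into the RNN as part of the feed_dict, but (1) I'm surprised I need to, given that we already set stateful=True, and (2) I don't know how I would solve this when using an Estimator.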
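A likely explanation, stated as our understanding of TF 1.x tf.keras rather than something confirmed in this thread: in graph mode a stateful RNN does not write its new state as part of the `output` tensor. `RNN.call` creates separate assign ops and registers them with `add_update`, so they show up in `rnn.updates`. Keras' own predict/fit loops run those updates, but `sess.run(output)` evaluates only the ops that `output` depends on, so the assignments never execute and every run reads the zero-initialized state again. Fetching the updates too, e.g. `sess.run([output, rnn.updates], feed_dict=...)`, should then keep the state; please check this against your TF version. The NumPy sketch below (the class and the `run_updates` flag are hypothetical) models the difference between "fetch the output only" and "also run the state update":

```python
import numpy as np


class GraphModeStatefulRNN:
    """Hypothetical model of a stateful 1-unit linear RNN in graph mode:
    computing the output and writing back the new state are two separate
    operations, like an `output` tensor and a separate assign op."""

    def __init__(self):
        self.kernel = np.float32(1.0)
        self.recurrent_kernel = np.float32(1.0)
        self.state = np.float32(0.0)  # the zero-initialized "state variable"

    def output(self, x):
        # Reads the stored state but never modifies it (like fetching `output`).
        return x * self.kernel + self.state * self.recurrent_kernel

    def run(self, x, run_updates):
        out = self.output(x)
        if run_updates:
            # The state assignment only happens if it is explicitly executed.
            self.state = out
        return out


output_only = GraphModeStatefulRNN()
print([float(output_only.run(np.float32(1.0), run_updates=False)) for _ in range(5)])
# [1.0, 1.0, 1.0, 1.0, 1.0]

with_updates = GraphModeStatefulRNN()
print([float(with_updates.run(np.float32(1.0), run_updates=True)) for _ in range(5)])
# [1.0, 2.0, 3.0, 4.0, 5.0]
```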
Answer 0 (score: 0)
The problem here is that you don't pass an initial state when you call the RNN. Here is how to modify your code:
import tensorflow as tf
import numpy as np
with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=(1, 1, 1))
    rnn = tf.keras.layers.SimpleRNN(
        1,
        stateful=True,
        return_sequences=True,
        activation=None,
        use_bias=False,
        kernel_initializer='ones',
        recurrent_initializer='ones',
        return_state=True  # also return the final state so it can be fed back
    )
    current_state = np.zeros((1, 1))  # state fed to the first step
    state_placeholder = tf.placeholder(tf.float32, shape=[1, 1])
    output, state = rnn(x, initial_state=state_placeholder)  # You need to explicitly pass the state in the next step
    init = tf.global_variables_initializer()
    sess.run(init)
    for _ in range(5):
        op_val, state_val = sess.run([output, state],
                                     feed_dict={x: np.ones([1, 1, 1]).astype(np.float32),
                                                state_placeholder: current_state.astype(np.float32)})
        current_state = state_val  # carry the returned state into the next run
        print(op_val)
    # print(sess.run(output, feed_dict={x: np.ones([1, 1, 1]).astype(np.float32)}))
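The pattern in this answer is "state threading": each run takes the previous state as an input and returns the new state, and the caller feeds it back in, so nothing depends on a hidden state variable being updated. A framework-free NumPy sketch of the same loop, where `linear_rnn_step` is a hypothetical stand-in for `rnn(x, initial_state=...)` with `return_state=True` and all-ones weights:

```python
import numpy as np


def linear_rnn_step(x, state, kernel=1.0, recurrent_kernel=1.0):
    """Hypothetical pure function: runs a 1-unit linear RNN over x
    (shape [batch, timesteps, 1]) starting from `state` (shape [batch, 1])
    and returns (per-step outputs, final state). Nothing is stored between calls."""
    h = state
    outputs = []
    for t in range(x.shape[1]):
        h = x[:, t, :] * kernel + h * recurrent_kernel
        outputs.append(h)
    return np.stack(outputs, axis=1), h


current_state = np.zeros((1, 1), dtype=np.float32)
x = np.ones([1, 1, 1], dtype=np.float32)
for _ in range(5):
    # Feed the previous state in and keep the returned one for the next call.
    op_val, current_state = linear_rnn_step(x, current_state)
    print(op_val)
```

Because the step function is pure, the caller decides whether to carry the state forward or reset it (by passing zeros again), which is also what feeding `state_placeholder` achieves in the TensorFlow version.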