I have a question about my RNN implementation.
I have the following code:
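import numpy as np
import theano
import theano.tensor as T

def one_step(x_t, h_tm1, W_ih, W_hh, b_h, W_ho, b_o):
    # new hidden state from the current input and the previous hidden state
    h_t = T.tanh(
        theano.dot(x_t, W_ih) +
        theano.dot(h_tm1, W_hh) +
        b_h
    )
    # linear readout of the hidden state
    y_t = theano.dot(h_t, W_ho) + b_o
    return [h_t, y_t]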
n_hid = 3
n_in = 1
n_out = 1
W_hh_values = np.array(np.random.uniform(size=(n_hid, n_hid), low=-.01, high=.01), dtype=dtype)
W_hh2_values = np.array(np.random.uniform(size=(n_hid, n_hid), low=-.01, high=.01), dtype=dtype)
h0_value = np.array(np.random.uniform(size=(n_hid), low=-.01, high=.01), dtype=dtype)
b_h_value = np.array(np.random.uniform(size=(n_hid), low=-.01, high=.01), dtype=dtype)
b_h2_value = np.array(np.random.uniform(size=(n_hid), low=-.01, high=.01), dtype=dtype)
W_ih_values = np.array(np.random.uniform(size=(n_in, n_hid), low=-.01, high=.01), dtype=dtype)
W_ho_values = np.array(np.random.uniform(size=(n_hid, n_out), low=-.01, high=.01), dtype=dtype)
b_o_value = np.array(np.random.uniform(size=(n_out), low=-.01, high=.01), dtype=dtype)
# parameters of the rnn
b_h = theano.shared(b_h_value)
b_h2 = theano.shared(b_h2_value)
h0 = theano.shared(h0_value)
W_ih = theano.shared(W_ih_values)
W_hh = theano.shared(W_hh_values)
W_hh2 = theano.shared(W_hh2_values)
W_ho = theano.shared(W_ho_values)
b_o = theano.shared(b_o_value)
params = [W_ih, W_hh, b_h, W_ho, b_o, h0]
# target values
t = T.matrix(dtype=dtype)
# hidden and outputs of the entire sequence
[h_vals, y_vals], _ = theano.scan(
    fn=one_step,
    sequences=dict(input=x, taps=10),
    outputs_info=[h0, None],  # corresponds to the return type of one_step
    non_sequences=[W_ih, W_hh, b_h, W_ho, b_o]
)
learn_rnn_fn = theano.function(
    [],
    outputs=cost,
    updates=updates,
    givens={
        x: s_,
        t: t_
    }
)
Now, after training, I can of course predict the outputs:
test_rnn_fn = theano.function(
    [],
    outputs=y_vals,
    givens={x: s_2}
)
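However, this runs the network in predictive mode (i.e. it takes X steps of input and predicts the outputs). I would like to run it in generative mode, meaning that I want to start from an initial state, let the RNN run for an arbitrary number of steps, and feed its output back in as its input.
How can I do that?
Thanks!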
Answer 0 (score: 3)
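You can run scan for an arbitrary number of steps using the n_steps parameter. To pass y_t on to the next computation, and assuming x_t.shape == y_t.shape, you can use outputs_info=[h0, x_t] as the argument to scan and change the step function to one_step(h_tm1, y_tm1, W_ih, W_hh, b_h, W_ho, b_o), so that the previous output y_tm1 takes the place of the input x_t.
More complex termination conditions can be created with theano.scan_module.until(), but those are a matter of design rather than implementation. See here for an example.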
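To make the feedback loop concrete, here is a minimal sketch of generative mode written in plain NumPy rather than Theano, so it runs without a Theano install. It mirrors the question's one_step and feeds each output back in as the next input, which is what scan does when it is given n_steps and an initial value for both outputs. The generate and init helpers, the zero starting input and the random weights are illustrative assumptions, not part of the original code; in practice you would plug in the trained parameter values.

```python
import numpy as np

def one_step(x_t, h_tm1, W_ih, W_hh, b_h, W_ho, b_o):
    """One RNN step with the same formula as the question's one_step."""
    h_t = np.tanh(np.dot(x_t, W_ih) + np.dot(h_tm1, W_hh) + b_h)
    y_t = np.dot(h_t, W_ho) + b_o
    return h_t, y_t

def generate(x0, h0, n_steps, W_ih, W_hh, b_h, W_ho, b_o):
    """Run the RNN for n_steps, feeding each output back as the next input.

    Hypothetical helper; it requires n_in == n_out so that y_t can replace x_t.
    """
    x_t, h_t = x0, h0
    outputs = []
    for _ in range(n_steps):
        h_t, y_t = one_step(x_t, h_t, W_ih, W_hh, b_h, W_ho, b_o)
        outputs.append(y_t)
        x_t = y_t  # the output becomes the next input
    return np.stack(outputs)

# Assumed stand-ins for trained weights, using the shapes from the question.
rng = np.random.RandomState(0)
n_in, n_hid, n_out = 1, 3, 1

def init(*shape):
    return rng.uniform(-0.01, 0.01, size=shape)

W_ih, W_hh, W_ho = init(n_in, n_hid), init(n_hid, n_hid), init(n_hid, n_out)
b_h, b_o, h0 = init(n_hid), init(n_out), init(n_hid)

# Start from a zero input and the initial hidden state, then generate 20 steps.
generated = generate(np.zeros(n_in), h0, 20, W_ih, W_hh, b_h, W_ho, b_o)
print(generated.shape)
```

In the Theano version the Python loop disappears: the step function drops its x_t argument in favour of y_tm1, and scan is called without sequences, with n_steps set and with an initial value supplied for both h and y in outputs_info.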