我正在Keras中实现我自己的循环图层,在step
函数内部,我希望能够跨越所有时间步骤访问隐藏状态,而不仅仅是默认情况下的最后一个状态,以便我可以做一些事情,比如及时向后添加跳过连接。
我正在尝试修改tensorflow后端_step
内的K.rnn
以返回到目前为止所有隐藏状态。我最初的想法是简单地将每个隐藏状态存储到TensorArray,然后将所有这些状态传递给step_function
(即我的图层中的step
函数)。我当前修改的函数如下,它将每个隐藏状态写入TensorArray states_ta_t
:
def _step(time, output_ta_t, states_ta_t, *states):
current_input = input_ta.read(time)
# Here I'd like to return all states up to current time
# and pass to step_function, instead of just the last
states = [states_ta_t.read(time)]
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t, states_ta_t) + tuple(new_states)
此版本仅返回最后一个状态,就像原始实现一样,并且作为普通RNN工作。如何获取目前为止的所有状态,存储在数组中并传递给step_function
?感觉这应该非常简单,但是我不太熟悉使用TensorArrays ......
(注意:这在展开的版本中比在符号版本中更容易,但不幸的是,我的实验使用展开的版本会耗尽内存)
答案 0 :(得分:2)
- 已编辑 -
我发现我误解了你的问题,我非常抱歉...
简而言之,试试这个:
states = states_ta_t.stack()[:time]
以下是一些解释:您确实将所有这些状态存储在states_ta_t
中,但您只将最后一个状态传递给step_function
。
您在代码中所做的是:
# Param 'time' refers to 'current time step'
states = [states_ta_t.read(time)]
这意味着您正在阅读当前的'从states_ta_t
开始,换句话说,就是最后一个州。
如果你想做一些切片,也许stack
函数会有所帮助。例如:
states = states_ta_t.stack()[:time]
但是我不确定这是否是一个正确的实现,因为我不熟悉TensorArray ...
希望它有所帮助!如果不是,如果您愿意发表评论并与我讨论,我感到非常荣幸!