在Keras的RNN中返回所有州的时间步长

时间:2017-05-30 22:02:33

标签: python tensorflow keras recurrent-neural-network

我正在Keras中实现我自己的循环图层,在step函数内部,我希望能够跨越所有时间步骤访问隐藏状态,而不仅仅是默认情况下的最后一个状态,以便我可以做一些事情,比如及时向后添加跳过连接。

我正在尝试修改tensorflow后端_step内的K.rnn以返回到目前为止所有隐藏状态。我最初的想法是简单地将每个隐藏状态存储到TensorArray,然后将所有这些状态传递给step_function(即我的图层中的step函数)。我当前修改的函数如下,它将每个隐藏状态写入TensorArray states_ta_t

   def _step(time, output_ta_t, states_ta_t, *states):
            current_input = input_ta.read(time)
            # Here I'd like to return all states up to current time
            # and pass to step_function, instead of just the last
            states = [states_ta_t.read(time)]
            output, new_states = step_function(current_input,
                                               tuple(states) +
                                               tuple(constants))
            for state, new_state in zip(states, new_states):
                new_state.set_shape(state.get_shape())
            states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states
            output_ta_t = output_ta_t.write(time, output)
            return (time + 1, output_ta_t, states_ta_t) + tuple(new_states) 

此版本仅返回最后一个状态,就像原始实现一样,并且作为普通RNN工作。如何获取目前为止的所有状态,存储在数组中并传递给step_function?感觉这应该非常简单,但是我不太熟悉使用TensorArrays ......

(注意:这在展开的版本中比在符号版本中更容易,但不幸的是,我的实验使用展开的版本会耗尽内存)

1 个答案:

答案 0 :(得分:2)

- 已编辑 -

我发现我误解了你的问题,我非常抱歉...

简而言之,试试这个:

states = states_ta_t.stack()[:time]

以下是一些解释:您确实将所有这些状态存储在states_ta_t中,但您只将最后一个状态传递给step_function

您在代码中所做的是:

# Param 'time' refers to 'current time step'
states = [states_ta_t.read(time)]

这意味着您正在阅读当前的'从states_ta_t开始,换句话说,就是最后一个州。

如果你想做一些切片,也许stack函数会有所帮助。例如:

states = states_ta_t.stack()[:time]

但是我不确定这是否是一个正确的实现,因为我不熟悉TensorArray ...

希望它有所帮助!如果不是,如果您愿意发表评论并与我讨论,我感到非常荣幸!