动态RNN填充和索引以匹配基础事实

时间:2017-07-21 12:36:49

标签: python indexing tensorflow tensor

我正在运行RNN(多对多)。我为每个用户设置了不同的时间维度。例如。用户1有5个时间戳,用户2有8个时间戳等。我认为RNN只接受具有恒定尺寸的张量,所以我目前用零填充(不使用tensorflow)时间维度,直到达到最大时间戳批处理中的所有用户(max_user_time)。例如,如果用户#1有2个时间戳和3个特征,我得到一个尺寸为[1,2,3]的张量:

|1 2|
|2 5|
|6 3|

如果批次中的用户3有3个时间戳,那么我们需要添加paading,以便用户1具有尺寸为[1,3,3]的张量:

|1 2 0|
|2 5 0|
|6 3 0|

每个用户的填充长度不同。

有没有办法一次性为所有用户使用tf.pad或类似内容?

填充后,我将这些张量作为输入传递给RNN并重新整形输出:

outputs,states=tf.nn.dynamic_rnn(lstm_cell,inputs=input
                                     ,dtype=tf.float32,sequence_length=max_batch)

reshape_out=tf.reshape(outputs,[-1,n_hidden])

对于序列长度参数,我传递一个带有每个用户的时间戳的向量,因此如果用户已超过其最大时间,则期望零输出 - 根据tf.dynamic_rnn文档。

所以,我来自重塑张量大小[batch_sizeXmax_user_time,n_hidden]。

这是一个比地面实况测量器更大的张量,它更小,并且根据其时间戳为每个用户提供行。

是否有一种简单的方法可以使用张量流来仅选择观察到的行来计算损失?

0 个答案:

没有答案