Tensorflow:使用一个张量来索引另一个

时间:2018-05-16 02:05:56

标签: tensorflow

作为这个问题的动机,我尝试使用tf.nn.dynamic_rnn的可变长度序列。当我使用batch_size=1(一次一个元素)进行训练时,一切都在游动,但现在我试图增加批量大小,这意味着零填充序列的长度相同。

我对所有序列进行了零填充(或截断),最大长度为15000.

outputs(来自RNN)的形状为[batch_size, max_seq_length, num_units],具体就是[16, 15000, 64]

我还创建了一个seq_lengths张量,[batch_size],所以[16],对应于所有零填充序列的实际序列长度。

我添加了一个完全连接的图层,将以前的outputs[:,-1,:]乘以W,然后添加一个偏差项,因为最终我只想尝试预测单个值(或者更确切地说) batch_size值)。但是,现在,我不能天真地使用-1作为索引,因为所有的序列都被不同地填充了!我有seq_lengths,但我不确定如何使用它来索引outputs。我已经四处搜索了,我认为答案是对tf.gather_nd的一些巧妙使用,但我无法弄明白。我可以很容易地看到如何获取单个值,但我想保留整个切片。我是否需要创建某种巨大的3D蒙版?

这就是我想要的Python理解(输出是np.array):outputs = np.array([outputs[i, seq_lengths[i], :] for i in range(batch_size)])

我很感激任何帮助!谢谢。

1 个答案:

答案 0 :(得分:0)

实际上,Alex发现你已经回答了我的问题:)。

经过一些更多的研究,我发现了以下内容,这正是我的用例:https://stackoverflow.com/a/43298689/5526865。我不会在这里复制代码,但只是检查一下。