收集动态索引列表

时间:2017-03-16 23:08:19

标签: numpy tensorflow

任何人都可以帮忙弄清楚如何抓取一组嵌入物吗?

我有一些代码可以预测每个索引的概率,然后选择max:

# U is batch_size x max_sentence_length x embedding_size
scores_per_index = find_start_preds(U ...) # batch_size x max_sentence_length x 1
start_preds = tf.argmax(alpha, axis=1) # batch_size x 1

如果可能的话,我想重新抓取与每个开始预测相关联的嵌入词。那可能吗?这就是我的想法,但它不起作用:(

u_s = U[:, start_preds, :]

1 个答案:

答案 0 :(得分:1)

你应该可以使用tf.gather来获得你想要的东西,但它只适用于领先的索引,所以你需要重新排序:

U2 = tf.transpose(U, [1, 0, 2])
u_s = tf.gather(U2, start_preds)
u_s = tf.transpose(u_s, [1, 0, 2])