如何在TensorFlow中获取特定的张量行?

时间:2016-07-19 20:55:34

标签: python numpy tensorflow

我有几个张量:

logits:此张量包含最终预测分数。

tf.Tensor 'MemN2N_1/MatMul_3:0' shape=(?, 18230) dtype=float32

最终预测计算为predict_op = tf.argmax(logits,1,name =" predict_op")

现在我想将预测限制在某些特定的列中。以下两个张量包含我想要选择的列索引。

self._stories属于

类型
tf.Tensor 'stories:0' shape=(?, 12, 110) dtype=int32

self._queries的类型为

tf.Tensor 'queries:0' shape=(?, 110) dtype=int32

这里110列是我想限制登录的索引号。例如,如果logits = [[10,20,30,40,50],[10,20,30,40,50] ..]和self._stories = [[[1,4,...], [1,2,4,...],...],[[0,4,...],[2,4 ...],...] ...]和self._queries = [[1,4 ...],[2,4,...],...]然后logits应该像[[20,30,50],[10,30,50] ...] < / p>

如何在tensorflow中进行这种索引过滤?

3 个答案:

答案 0 :(得分:0)

你试过tf.equal吗?这比较了两个张量并创建了一个新的张量,其中包含True,其中相等,而False则不相等。

使用这个bool-tensor,你可以根据你在第一步中创建的bool值来提供tf.select,它从一个张量或另一个张量中选择。

没有深入了解您提供的特定形状,但是通过这两个操作,您可以创建您要求的流程。

答案 1 :(得分:0)

尝试tf.gather

row = tf.squeeze(row, squeeze_dims=0)

您可以使用tf.squeeze删除尺寸为1的主要尺寸:

id

答案 2 :(得分:0)

你可以使用 tf.slice(input_, begin, size, name=None)。 有关更多详细信息,请参阅 TensorFlow 文档: https://www.tensorflow.org/api_docs/python/tf/slice