获取满足条件张量流的张量行

时间:2018-09-02 07:01:11

标签: numpy tensorflow

input_mb = tf.placeholder(tf.int32, [None, 166, 1], name="input_minibatch")

让我们说上面的代码。我想获取上述minibatch张量的行,以使每个检索到的行的第一个元素== a。我如何在Tensorflow中做到这一点?另外,您如何在Numpy中做到这一点?

2 个答案:

答案 0 :(得分:0)

(给出一个值a)

要在 numpy 中实现这一目标,您只需编写:

selected_rows = myarray[myarray[:,0]== a]

tensorflow 中,使用tf.where:

mytensor[tf.squeeze(tf.where(tf.equal(mytensor[:,0],a), None, None))

答案 1 :(得分:0)

我会在tensorflow上这样做:

tf.gather(mytensor, tf.squeeze(tf.where(tf.equal(mytensor[:,0],a), None, None)), axis=0)