Tensorflow:使用tf.where()时如何保持批次尺寸?

时间:2019-09-13 09:37:45

标签: python tensorflow

我正在尝试选择不同于零的元素,并在以后使用它们。我的输入张量具有批次大小,因此我想保留它并且不要在批次中混合数据。我认为tf.gather_nd()对我有用,但是首先我必须获取所需数据的索引,然后我发现tf.where()。我尝试了以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

我希望indexes保持批次尺寸,但是形状为[7, 2]。我怀疑问题出在不同批次中具有满足条件的不同点数。

有没有一种方法可以使索引保持批次尺寸?预先感谢。

编辑: indexes的形状为[7, 3],其中第一个暗指点数,第二个暗指点的位置(包括它属于哪个批次) )。但是我需要indexes具有特定的批次维,因为稍后我想用它来收集img中的数据:

Y = tf.gather_nd(img, indexes)

我希望Y具有批处理维度,但是由于indexes没有,我得到了一个混合了来自不同bateches数据的平坦张量。

1 个答案:

答案 0 :(得分:0)

实际上,您可能做错了一些:运行代码时,indexes的尺寸为(7,3),而不是(7,2)3对应于您的3个维度,而7对应于img中非零元素的数量。

sess.run(indexes)的完整结果:

array([[0, 0, 0],
      [0, 1, 2],
      [0, 2, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 0, 2],
      [1, 1, 2]])