我正在尝试选择不同于零的元素,并在以后使用它们。我的输入张量具有批次大小,因此我想保留它并且不要在批次中混合数据。我认为tf.gather_nd()
对我有用,但是首先我必须获取所需数据的索引,然后我发现tf.where()
。我尝试了以下方法:
img = tf.constant([[[1., 0., 0.],
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]
indexes = tf.where(tf.not_equal(img, 0.))
我希望indexes
保持批次尺寸,但是形状为[7, 2]
。我怀疑问题出在不同批次中具有满足条件的不同点数。
有没有一种方法可以使索引保持批次尺寸?预先感谢。
编辑: indexes
的形状为[7, 3]
,其中第一个暗指点数,第二个暗指点的位置(包括它属于哪个批次) )。但是我需要indexes
具有特定的批次维,因为稍后我想用它来收集img
中的数据:
Y = tf.gather_nd(img, indexes)
我希望Y
具有批处理维度,但是由于indexes
没有,我得到了一个混合了来自不同bateches数据的平坦张量。
答案 0 :(得分:0)
实际上,您可能做错了一些:运行代码时,indexes
的尺寸为(7,3)
,而不是(7,2)
。 3
对应于您的3个维度,而7
对应于img
中非零元素的数量。
sess.run(indexes)
的完整结果:
array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1],
[1, 0, 0],
[1, 0, 1],
[1, 0, 2],
[1, 1, 2]])