当src
具有形状[?]
时,tf.gather(src, tf.where(src != 0))
返回具有形状[?, 0]
的张量。我不确定尺寸的大小如何为0,并且我尤其不确定如何将张量改回来。我也没有在文档中找到任何可以解释这一点的东西。
我尝试进行tf.transpose(tensor)[0]
,但是转置张量的第一个维度的大小为0,无法访问!怎么了?
答案 0 :(得分:2)
我认为您应该使用tf.not_equal
对张量进行元素比较。
src = tf.constant([0, 1, 1, 0], dtype=tf.int8)
tf.gather(src, tf.where(tf.not_equal(src, 0))).eval(session=tf.Session())
array([[1],
[1]], dtype=int8)
您也可以将其缩短一点,并使用tf.boolean_mask
代替tf.where
和tf.gather
:
tf.boolean_mask(src, tf.not_equal(src, 0)).eval(session=tf.Session())
array([1, 1], dtype=int8)
请注意输出形状的差异。