我有一个3d张量,形状如下(2,?,1)
t = np.array([[[1],[0],[1]],[[1],[0],[1]]])
我要创建一个为True的蒙版,其中t
以上为非零
mask = t > 0
我还有另一个张量t1
,其形状等于t
,如下所示
t1 = np.array([[[2.1],[3.1],[1.2]],[[2.3],[1.9],[1.1]]])
我想过滤t1
为非零的t
。我正在使用tf.boolean_mask
,但它为我提供了形状为(?,)
的张量。我想要一个形状为(2,?,1)
t_m= tf.boolean_mask(t1, mask) # [[1, 2], [5, 6]]
sess = tf.Session()
out = sess.run(t_m)
输出为
array([2.1, 1.2, 2.3, 1.1])
我想要类似的东西
array([[2.1,1.2],[2.3,1.1]])