tf.boolean_mask()在3d张量上带有2d遮罩

时间:2018-10-11 17:31:42

标签: python tensorflow

我有一个3d张量,形状如下(2,?,1)

t = np.array([[[1],[0],[1]],[[1],[0],[1]]])

我要创建一个为True的蒙版,其中t以上为非零

mask = t > 0

我还有另一个张量t1,其形状等于t,如下所示

t1 = np.array([[[2.1],[3.1],[1.2]],[[2.3],[1.9],[1.1]]])

我想过滤t1为非零的t。我正在使用tf.boolean_mask,但它为我提供了形状为(?,)的张量。我想要一个形状为(2,?,1)

的张量
t_m= tf.boolean_mask(t1, mask)  # [[1, 2], [5, 6]]
sess = tf.Session()
out = sess.run(t_m)

输出为

array([2.1, 1.2, 2.3, 1.1])

我想要类似的东西

array([[2.1,1.2],[2.3,1.1]])

0 个答案:

没有答案