tf.boolean_mask`

时间:2018-08-06 23:31:00

标签: tensorflow

我想知道在tf.boolean_mask之后,如何从pred(一维张量)和mask(具有'True'或' False”),其中False默认为0

labels = np.random.rand(256, 256)
heatmaps = np.random.rand(256, 256, 10)

mask = labels > 0.5
heatmaps = tf.boolean_mask(heatmaps, mask)

scores = tf.nn.softmax(logits=heatmaps)
pred = tf.argmax(scores, axis=1)

顺便说一句,使用numpy易于实现:

pred_reshape = np.zeros((256, 256))
pred_ = sess.run(pred)
pred_reshape[mask] = pred_

但是,它需要将tensor转换回numpy数组。

1 个答案:

答案 0 :(得分:1)

也许您可以在计算乘以蒙版之后进行蒙版。结果将是相同的,并且您将保持形状(256,256)。

import tensorflow as tf
import numpy as np

labels = np.random.rand(256,256)
mask = labels > 0.5
heatmaps = np.random.rand(256, 256, 10)

scores = tf.nn.softmax(logits=heatmaps)
pred = tf.argmax(scores, axis=2) # Note that now the 10-element axis is 2 (256,256,10)

# Masking
pred_reshape = tf.multiply(pred, mask)

缺点是对要屏蔽的数据进行了不必要的计算。

请注意,现在pred_reshapeint而不是float的数组。

作为支票:

import tensorflow as tf
import numpy as np

labels = np.random.rand(256,256)
mask = labels > 0.5
heatmaps = np.random.rand(256, 256, 10)

sess = tf.Session()

# Original code
original_heatmaps = tf.boolean_mask(heatmaps, mask)
original_scores = tf.nn.softmax(logits=original_heatmaps)
original_pred = tf.argmax(original_scores, axis=1)

original_pred_reshape = np.zeros((256,256))
original_pred_ = sess.run(original_pred)
original_pred_reshape[mask] = original_pred_

# New code
new_scores = tf.nn.softmax(logits=heatmaps)
new_pred = tf.argmax(new_scores, axis=2)
new_pred_reshape = tf.multiply(new_pred, mask)
new_result = sess.run(new_pred_reshape)

print('All elements equal:', np.all(original_pred_reshape==new_result))

请注意,argmax返回最高元素的索引,在某些情况下为0。因此,如果需要,您将无法将这些元素与被屏蔽的元素区分开:两种情况都在前一种情况下pred_reshape = np.zeros (( 256, 256)),就像第二个pred_reshape = tf.multiply (pred, mask)一样。

如果您需要区分被遮罩的元素,也许可以执行以下操作:

import tensorflow as tf
import numpy as np

labels = np.random.rand(256,256)
mask = labels > 0.5
heatmaps = np.random.rand(256, 256, 10)

scores = tf.nn.softmax(logits=heatmaps)
pred = tf.argmax(scores, axis=2) # Note that now the 10-element axis is 2 (256,256,10)

# Masking
pred_reshape = tf.add(pred, 1)
pred_reshape = tf.multiply(pred_reshape, mask)
pred_reshape = tf.add(pred_reshape, -1)

您将获得具有-1值的蒙版元素。