假设我有一个大小为BxWxHxD的张量。我想处理张量,使得我有一个新的BxWxHxD张量,其中只保留每个WxH切片中的最大元素,并且所有其他值都为零。
换句话说,我认为实现这一目标的最佳方法是以某种方式在WxH切片上采用2D argmax,从而产生行和列的BxD索引张量,然后可以将其转换为单热BxWxHxD张量用作面具。我如何使这项工作?
答案 0 :(得分:2)
您可以使用以下功能作为起点。它计算每个批次和每个渠道的最大元素的索引。生成的数组采用格式(批量大小,2,通道数)。
def argmax_2d(tensor):
# input format: BxHxWxD
assert rank(tensor) == 4
# flatten the Tensor along the height and width axes
flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))
# argmax of the flat tensor
argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)
# convert indexes into 2D coordinates
argmax_x = argmax // tf.shape(tensor)[2]
argmax_y = argmax % tf.shape(tensor)[2]
# stack and return 2D coordinates
return tf.stack((argmax_x, argmax_y), axis=1)
def rank(tensor):
# return the rank of a Tensor
return len(tensor.get_shape())