Tensorflow中的等效但可区分的argmax表达式

时间:2018-07-06 22:07:25

标签: python tensorflow

我需要一个热表示来表示张量中的最大值。

例如,考虑张量2 x 3

[ [1, 5, 2],
  [0, 3, 7] ]

我想要的 one-hot-argmax 表示形式如下:

[ [0, 1, 0],
  [0, 0, 1] ]

我可以按照以下步骤进行操作,其中my_tensorN x 3张量:

position = tf.argmax(my_tensor, axis=1).      # Shape (N x )
one_hot_pos = tf.one_hot(position, depth=3)   # Shape (N x 3)

但是由于我正在对其进行训练,因此这部分代码需要区分。 我的解决方法如下,其中EPSILON = 1e-3是一个小常数:

max_value = tf.reduce_max(my_tensor, axis=1, keepdims=True)
clip_min = max_value - EPSILON
one_hot_pos = (tf.clip_by_value(my_tensor, clip_min, max_value) - clip_min) / (max_value - clip_min)

该解决方法大部分时间都有效,但是-如预期的那样-它存在一些问题:

  • EPSILON敏感:如果太小,可能会被零除
  • 无法解决平局:argmax即使在平局中也只能选择一个。

在解决上述两个问题,但仅使用可区分的Tensorflow函数时,您是否知道模拟argmaxone_hot情况的更好方法?

1 个答案:

答案 0 :(得分:0)

执行一些最大值,平铺和乘法运算。喜欢:

a = tf.Variable([ [1, 5, 2], [0, 3, 7] ]) # your tensor

m = tf.reduce_max(a, axis=1) #  [5,7]
m = tf.expand_dims(m, -1)    # [[5],[7]]
m = tf.tile(m, [1,3])        # [[5,5,5],[7,7,7]]

y = tf.cast(tf.equal(a,m), tf.float32)) # [[0,1,0],[0,0,1]]

这是一个棘手的乘法运算,可微分。