在TensorFlow中,如何有选择地仅将函数应用于张量的每一行的最大元素?

时间:2017-02-24 20:45:24

标签: python tensorflow

我有一个2D张量,我想只将函数应用于每行的最大元素。例如,如果我的张量是:

[[0.5, 0.7],
 [0.4, 0.3]]

我的函数是乘以2,然后我希望我的结果是

[[0.5, 1.4],
 [0.8, 0.3]].

我目前的做法是使用tf.select,如下所示:

tf.select(tf.equal(my_tensor,
                   tf.reduce_max(my_tensor,axis=1)), 
          my_function(my_tensor),
          my_tensor)

我们的想法是使用tf.reduce_max创建一个张量,它是每行的最大值,然后使用tf.equal创建一个布尔张量,表示输入的每个元素是否等于其行max。然后将其作为掩码传递给tf.select,以确定是否应该应用该函数。但是,这似乎不起作用 - 最大张量与输入张量的维度不同,我认为广播存在一些问题。此外,如果存在两个相等的最大元素,则此方法可能会遇到问题。我宁愿使用一种自动断开关系的方法,也许使用tf.arg_max。

1 个答案:

答案 0 :(得分:2)

我能够通过在tf.reduce_max的结果上使用tf.expand_dims来修改问题中的代码,以获得要广播的正确形状的张量以匹配输入张量。结果如下:

tf.select(tf.equal(my_tensor,
                   tf.expand_dims(tf.reduce_max(my_tensor,axis=1), axis=1)),
          my_function(my_tensor),
          my_tensor)

虽然这有效,但我怀疑有一种更清洁的方法。