我有一个2D张量,我想只将函数应用于每行的最大元素。例如,如果我的张量是:
[[0.5, 0.7],
[0.4, 0.3]]
我的函数是乘以2,然后我希望我的结果是
[[0.5, 1.4],
[0.8, 0.3]].
我目前的做法是使用tf.select
,如下所示:
tf.select(tf.equal(my_tensor,
tf.reduce_max(my_tensor,axis=1)),
my_function(my_tensor),
my_tensor)
我们的想法是使用tf.reduce_max创建一个张量,它是每行的最大值,然后使用tf.equal创建一个布尔张量,表示输入的每个元素是否等于其行max。然后将其作为掩码传递给tf.select,以确定是否应该应用该函数。但是,这似乎不起作用 - 最大张量与输入张量的维度不同,我认为广播存在一些问题。此外,如果存在两个相等的最大元素,则此方法可能会遇到问题。我宁愿使用一种自动断开关系的方法,也许使用tf.arg_max。
答案 0 :(得分:2)
我能够通过在tf.reduce_max的结果上使用tf.expand_dims来修改问题中的代码,以获得要广播的正确形状的张量以匹配输入张量。结果如下:
tf.select(tf.equal(my_tensor,
tf.expand_dims(tf.reduce_max(my_tensor,axis=1), axis=1)),
my_function(my_tensor),
my_tensor)
虽然这有效,但我怀疑有一种更清洁的方法。