Numpy的nanmax在Tensorflow中的功能?

时间:2016-08-17 20:26:58

标签: tensorflow

我有一个大小为X的数组[?, n](第一个维度是批量大小)。 我想得到k个最大数字(对于批处理中的每个向量),其中一些是nan。在通常的numpy中我会使用函数nanmax,但它似乎不存在于Tensorflow中。

我尝试使用nn.top_k使用完整的向量(使用nans),但它似乎无法正常工作 - 它会返回一个甚至有nan个值的向量虽然有足够的数字来填充输入的k

感谢。

1 个答案:

答案 0 :(得分:1)

您可以用负无穷大替换所有nan值并使用top_k

tf.reset_default_graph()
a = tf.constant([0, 1, 2, 3], dtype=tf.float32)
b = tf.constant([0, 1, 2, 3], dtype=tf.float32)
c = a/b
sess = tf.InteractiveSession()
print sess.run(tf.nn.top_k(c)[0])
infinities = tf.constant(-np.inf, shape=(4,))
c_fixed = tf.select(tf.is_nan(c), infinities, c)
print sess.run(tf.nn.top_k(c_fixed)[0])

这给出了

[ nan]
[ 1.]