我有一个大小为X
的数组[?, n]
(第一个维度是批量大小)。
我想得到k
个最大数字(对于批处理中的每个向量),其中一些是nan
。在通常的numpy中我会使用函数nanmax
,但它似乎不存在于Tensorflow中。
我尝试使用nn.top_k
使用完整的向量(使用nans
),但它似乎无法正常工作 - 它会返回一个甚至有nan
个值的向量虽然有足够的数字来填充输入的k
。
感谢。
答案 0 :(得分:1)
您可以用负无穷大替换所有nan
值并使用top_k
tf.reset_default_graph()
a = tf.constant([0, 1, 2, 3], dtype=tf.float32)
b = tf.constant([0, 1, 2, 3], dtype=tf.float32)
c = a/b
sess = tf.InteractiveSession()
print sess.run(tf.nn.top_k(c)[0])
infinities = tf.constant(-np.inf, shape=(4,))
c_fixed = tf.select(tf.is_nan(c), infinities, c)
print sess.run(tf.nn.top_k(c_fixed)[0])
这给出了
[ nan]
[ 1.]