在Tensorflow数据集中过滤NaN值

时间:2020-10-01 14:23:33

标签: python tensorflow tensorflow2.0 tensorflow-datasets

是否有一种简便的方法可以从nan实例中过滤所有包含tensorflow.data.Dataset值的条目?像熊猫中的dropna方法一样?


简短示例:

import numpy as np
import tensorflow as tf

X = tf.data.Dataset.from_tensor_slices([[1,2,3], [0,0,0], [np.nan,np.nan,np.nan], [3,4,5], [np.nan,3,4]])
y = tf.data.Dataset.from_tensor_slices([np.nan, 0, 1, 2, 3])
ds = tf.data.Dataset.zip((X,y))
ds = foo(ds)  # foo(x) = ?
for x in iter(ds): print(str(x))

foo(x)可以用来获得以下输出:

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

如果您想自己尝试,here is Google Colab notebook

2 个答案:

答案 0 :(得分:2)

怎么样:

def any_nan(t):
    return tf.reduce_sum(
        tf.cast(
            tf.math.is_nan(t),
            tf.int32,
        )
    ) > tf.constant(0)


>>> ds_filtered = ds.filter(lambda x, y: not any_nan(x) and not any_nan(y))
>>> for x in iter(ds_filtered): print(str(x))
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

答案 1 :(得分:1)

我采用的方法与现有答案略有不同。我使用tf.reduce_any而不是使用总和:

filter_nan = lambda x, y: not tf.reduce_any(tf.math.is_nan(x)) and not tf.math.is_nan(y)

ds = tf.data.Dataset.zip((X,y)).filter(filter_nan)

list(ds.as_numpy_iterator())
[(array([0., 0., 0.], dtype=float32), 0.0),
 (array([3., 4., 5.], dtype=float32), 2.0)]