在Tensorflow中过滤数据

时间:2019-10-10 22:20:26

标签: python tensorflow tensorflow-datasets tensorflow2.0

我在TensorFlow中有一些列数据,我想对其中一列进行过滤,如下所示:

// Flat list of { category, solution } records, one per product category.
// Built from a grouped spec so each solution's categories stay visibly together;
// the resulting array is identical to writing the nine literals out by hand.
const categories = [
    ['Data Solutions', ['Patch Leads', 'Cables']],
    ['Hardware', ['Nails', 'Locks', 'Screws']],
    ['Cabinet Solutions', ['Cabinets', 'Swing Frames', 'Racks']],
    ['Fire Solutions', ['Fire Cables']],
].flatMap(([solution, names]) =>
    names.map((category) => ({ category, solution })),
);

// Value object for a single category inside a solution.
// NOTE: lower-case class name is kept as-is — callers below reference it.
// `slug` intentionally mirrors `name` (no slugification is performed here).
class category {
    constructor(id, name) {
        Object.assign(this, { id, name, slug: name });
    }
}
// A solution grouping: an id, a display name, and the categories it contains.
// The `categories` default is evaluated per call, so each instance gets its
// own fresh array when the argument is omitted.
class NewOne {
    constructor(id, name, categories = []) {
        Object.assign(this, { id, name, categories });
    }
}
// Group consecutive `categories` entries into `solutions`, one NewOne per
// distinct solution name (input is already ordered by solution).
//
// Fixes over the previous version:
//  - `for (index in categories)` yielded STRING keys, so solution ids were
//    strings like "2"; we now use .entries() and numeric indices.
//  - `index` was an implicit global; now block-scoped via const.
//  - `solutions` was seeded with categories[0] before a loop that started at
//    index 0, so the first category ('Patch Leads') was pushed twice.
//  - An empty `categories` array no longer crashes (it produces []).
let solutions = [];

let currentSolution = null;
for (const [index, entry] of categories.entries()) {
    if (currentSolution === null || currentSolution.name !== entry.solution) {
        // New solution: its id is the index where it first appears.
        currentSolution = new NewOne(index, entry.solution, [new category(0, entry.category)]);
        solutions.push(currentSolution);
    } else {
        // Same solution: append the category with the next sequential id.
        currentSolution.categories.push(new category(currentSolution.categories.length, entry.category));
    }
}

但这会产生错误消息:

  

import pandas as pd
import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
tfv1.enable_v2_behavior()

csv_file = tf.keras.utils.get_file('heart.csv', 'https://storage.googleapis.com/applied-dl/heart.csv')
df = pd.read_csv(csv_file)
target = df.pop('target')
df['thal'] = pd.Categorical(df['thal'])
df['thal'] = df.thal.cat.codes

# Use interleave() and prefetch() to read many files concurrently.
#files = tf.data.Dataset.list_files(file_pattern=input_file_pattern, shuffle=True, seed=123456789)
#dataset = files.interleave(lambda x: tf.data.RecordIODataset(x).prefetch(100), cycle_length=8)
#Pretend I actually had some data files
dataset = tf.data.Dataset.from_tensor_slices((df.to_dict('list'), target.values))
dataset = dataset.shuffle(1000, seed=123456789)
dataset = dataset.batch(20)
#Pretend I did some parsing here
# dataset = dataset.map(parse_record, num_parallel_calls=20)
dataset = dataset.filter(lambda x, label: x['trestbps']<135)

ValueError: `predicate` 的返回类型必须可转换为标量布尔张量。

我该怎么过滤数据?

1 个答案:

答案 0 :(得分:1)

这是因为您在batch之后才调用filter。因此,在lambda表达式中,x是形状为(None,)的批处理(将drop_remainder=True传递给batch可获得(20,)的形状),而不是单个样本。要解决此问题,您必须在batch之前调用filter。

还有一种解决方案:在batch之后使用map来进行“过滤”。但是,如您所见,这会带来使批量大小可变的副作用:输入的每个批次有20个样本,删除不满足条件(trestbps < 135)的元素后,每个批次剩余的元素数量并不相同。而且,此解决方案的性能非常差……

import timeit

import pandas as pd

import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
tfv1.enable_v2_behavior()

def s1(ds):
    """Filter individual samples, then batch the survivors.

    Filtering before batching means the predicate sees one sample at a
    time, so `features['trestbps'] < 135` is a scalar boolean — exactly
    what `tf.data.Dataset.filter` requires.
    """
    kept = ds.filter(lambda features, label: features['trestbps'] < 135)
    return kept.batch(20)

def s2(ds):
    """Batch first, then 'filter' via map by masking each batch.

    The boolean mask is applied to every feature column and to the
    labels, so batches shrink to only the rows where trestbps < 135 —
    which makes the resulting batch sizes variable.
    """
    def drop_high_trestbps(features, label):
        # Compute the per-row mask once and reuse it for every column.
        mask = features['trestbps'] < 135
        masked = tf.nest.map_structure(lambda column: column[mask], features)
        return masked, label[mask]

    return ds.batch(20).map(drop_high_trestbps)


def base_ds():
    """Download the Cleveland heart CSV and wrap it as a tf.data.Dataset.

    Returns a dataset of (features_dict, label) pairs; the categorical
    'thal' column is converted to integer category codes first.
    """
    csv_path = tf.keras.utils.get_file(
        'heart.csv',
        'https://storage.googleapis.com/applied-dl/heart.csv')

    frame = pd.read_csv(csv_path)
    labels = frame.pop('target')
    # Equivalent to assigning pd.Categorical then taking .cat.codes.
    frame['thal'] = pd.Categorical(frame['thal']).codes

    return tf.data.Dataset.from_tensor_slices((frame.to_dict('list'), labels.values))


def main():
    """Compare the two pipelines: batch shapes, equality, and timing."""
    source = base_ds()
    filter_then_batch = s1(source)
    batch_then_map = s2(source)

    # Print the per-batch tensor shapes for both strategies.
    tf.print("DS_S1:", [tf.nest.map_structure(lambda t: t.shape, batch)
                        for batch in filter_then_batch])
    tf.print("DS_S2:", [tf.nest.map_structure(lambda t: t.shape, batch)
                        for batch in batch_then_map])
    # Batches differ (variable sizes in s2) but the unbatched samples match.
    tf.print("Are equals?", list(filter_then_batch) == list(batch_then_map))
    tf.print("Contains same elements?",
             list(filter_then_batch.unbatch()) == list(batch_then_map.unbatch()))

    tf.print("Filter and batch:", timeit.timeit(lambda: s1(source), number=100))
    tf.print("Batch and map:", timeit.timeit(lambda: s2(source), number=100))


if __name__ == '__main__':
    main()

结果:

# Tensor shapes
[...]
Are equals? False
Contains same elements? True
Filter and batch: 0.5571189750007761
Batch and map: 15.582061060000342

此致