如何从tensorflow数据集中选择特定的列?

时间:2019-11-25 14:41:33

标签: python tensorflow tensorflow-datasets

我正在使用tf.data.Dataset预处理的CSV文件中的数据训练Tensorflow模型。但是,我希望模型分叉到与一组不同的csv列相对应的三个分支,并且model.fit需要为每个输出单独的数据集。 CSV文件的所有列都需要进行相同的预处理,因此最有效的准备方式是加载整个文件,对其进行处理,然后将数据集分为三部分。但是,我正在努力寻找一种方法。

我希望dataset.map允许我使用以下操作选择一些列:

dset = dset.map(lambda x: x[[1, 2, 3, 7]])

但似乎tensorflow将其解释为x[1][2][3][7]

我发现创建单独数据集的唯一可行方法是从一开始就这样做:

y = []
for cls, keys in output_classes.items():
    tmp = tf.data.experimental.CsvDataset(data_path, [tf.int32 for i in keys], select_cols=keys)
    [...]
    y.append(tmp)
y = tf.data.Dataset.zip(tuple(y))

不幸的是,它产生了很多不必要的开销,并且极大地减慢了训练速度。

是否可以通过功能的子集拆分tf​​.data.Dataset对象?

2 个答案:

答案 0 :(得分:1)

通过使用 .map() 修改 tornikeo 的答案,此解决方案对我有用。

dataset = tf.data.Dataset.from_tensor_slices([[1,2,3,4], 
                                              [5,6,7,8]])
dataset_filter = dataset.map(lambda x: tf.gather(x, [0, 2], axis=0))
result = list(dataset_filter.as_numpy_iterator())
print(result)

# Outputs array([1, 3], dtype=int32), array([5, 7])

答案 1 :(得分:0)

尝试tf.gather


tf.gather(tf.constant([1,2,3,4]), [1,2,3])
# ouputs : array([2, 3, 4])

如果您有高维数据,请使用tf.gather_nd