访问tf.data.Dataset的索引以删除和附加数据元素

时间:2019-08-19 03:07:03

标签: tensorflow keras

我目前正在从事一个研究项目,在这个项目中,我必须在某些时期之后追加训练集,并在评估后从测试集中删除一些样本。当前,没有办法可以访问tfrecord文件中的记录(放置在特定索引处)以进行删除或追加。由于tfrecords提供了非常快速的培训,因此我避免使用生成器。有任何建议以这种方式访问​​tfrecord文件吗?

1 个答案:

答案 0 :(得分:1)

  

在某些时期之后,我必须追加训练集

您可以使用repeat(n)重复n个时期的数据集,然后使用concatenate(new_dataset)附加一些额外的数据。因此,例如,为了在15个纪元后追加新数据,我们可以这样做:

dataset = tf.data.TFRecordDataset('filepath.tfrecord')
new_data = tf.data.TFRecordDataset('filepath_of_records_to_append.tfrecord') # or any other dataset from generator or whatever!

dataset = dataset.repeat(15).concatenate(new_data)
  

并从测试集中删除一些样本

您无法使用tf.data API轻松地从原始tfrecord文件中删除(您实际上必须编写一个新的tfrecord文件,并省略记录),所以也许相反,您发现仅构造一个tf.data数据集并忽略还是跳过您要省略的记录?这更加简单,可以使用take()skip()来完成。

例如,如果我想跳过索引4、7、8、9和10的记录,则可以执行以下操作:

dataset = tf.data.TFRecordDataset('filepath.tfrecord')

dataset = dataset.take(4).skip(1).take(2).skip(4)