生成器慢的张量流数据集

时间:2019-03-18 09:54:14

标签: python tensorflow tensorflow-datasets

我有一个数据文件,格式如下: A B C x,y,z

因为A B C非常大,所以我不想将其转换为以下格式的磁盘

  1. A B C x
  2. A B C y
  3. A B C z

现在我正在使用dataset.from_generator来读取它,但是训练时它非常慢。

我该如何解决?

logis是:

def gen_sample2(self):
    for file_name in self.file_names:

        with open(file_name) as f:
            for line in f:
                item = line.strip().split(',')
                pids_str = item[4]
                pids = [int(pid) for pid in pids_str.split('|') if pid in self.pid_set]
                for item_id in pids:
                    yield {"gender": item[1], "topics": item[2], "weights": item[3], "item_id": item_id}

def get_dataset_iterator(self):
    dataset = tf.data.Dataset.from_generator(self.gen_sample2,
                                             output_types={"gender":tf.string,
                                                           "topics":tf.string,
                                                           "weights":tf.string,
                                                           "item_id":tf.int32
                                                           })
    dataset = dataset.batch(self.batch_size)

    # if self.buffer_size > 0:
    #     dataset = dataset.shuffle(self.buffer_size)

    dataset = dataset.repeat(self.iter)
    return  dataset.make_one_shot_iterator()

0 个答案:

没有答案