Stitching TensorFlow datasets together based on a probability distribution

Asked: 2017-09-27 13:57:04

Tags: python python-3.x tensorflow tensorflow-datasets

I have a dataset of paths keyed by path length. The network needs to take paths of the same length and batch them together. I'm looking for a way (using the Dataset API) to pick a path length according to the distribution of path lengths (something as simple as P(length) = (number of paths of this length) / (total number of paths)), and then get a batch of paths of that length. I plan to write tfrecord files into different directories, one per length (I don't think I can use the python dict that stores these paths directly). How do I use the Dataset API to build something that returns me batch_size tfrecords from one of the directories, but picks the next directory to draw from according to the path-length distribution? So I figured I could build one dataset per path length, like:

import os
import tensorflow as tf

# Build one dataset per path length; each directory holds the tfrecord
# files for paths of that length.
datasets = {}
for length in path_dict.keys():
    paths_dir = '/../tfrecords/%d' % length
    filenames = [os.path.join(paths_dir, x) for x in os.listdir(paths_dir)]
    datasets[length] = tf.contrib.data.TFRecordDataset(filenames).batch(batch_size)

But how do I build a new dataset whose next element comes from the per-length dataset for some length x, with x chosen according to the length distribution?

(EDIT: I may be structuring this badly, e.g. I could have a single tfrecord file per length or something. Feedback appreciated.)
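(For reference: later TensorFlow releases added tf.data.experimental.sample_from_datasets (originally tf.contrib.data.sample_from_datasets), which does exactly this kind of weighted pick between datasets. A minimal sketch, assuming that API is available and that path_dict maps each length to its list of paths, as the loop above suggests:)

import tensorflow as tf

# Weight each per-length dataset by its share of the total path count,
# i.e. P(length) = (paths of this length) / (total paths).
total = sum(len(v) for v in path_dict.values())
weights = [len(path_dict[length]) / total for length in datasets]

# sample_from_datasets draws the next element from one of the input
# datasets at random according to `weights`; since every element of
# datasets[length] is already a batch, each draw yields a batch of
# equal-length paths.
mixed = tf.data.experimental.sample_from_datasets(
    list(datasets.values()), weights=weights)
next_batch = mixed.make_one_shot_iterator().get_next()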

EDIT2: In the old API I could do something like this:

import os
from collections import OrderedDict

import numpy as np


def _random_filenames(tfrecords_dir, batch_size, examples):
    paths = OrderedDict()
    # each tfrecords subdirectory is named after the length of the paths it contains
    tfrecords_dirs = sorted(map(int, os.listdir(tfrecords_dir)))
    for d in tfrecords_dirs:
        dir_ = os.path.join(tfrecords_dir, str(d))
        files = os.listdir(dir_)
        paths[d] = sorted(os.path.join(dir_, f) for f in files)
    # Path statistics: P(length) = (paths of this length) / (total paths)
    total_paths = sum(len(v) for v in paths.values())
    print('Total paths %d' % total_paths)
    stats = [len(v) / total_paths for v in paths.values()]
    lengths = list(paths.keys())
    filenames = []
    # For each batch, draw a length from the distribution, then draw
    # batch_size filenames of that length (with replacement).
    for j in range(0, examples, batch_size):
        pick = int(np.random.choice(lengths, p=stats))
        filenames.extend(np.random.choice(paths[pick], size=batch_size))
    return filenames

And then build a queue from those filenames:

def _get_filename_queue(filenames):
    # Create a queue that produces the filenames to read.
    # num_epochs makes no sense here, as we just repeat the same filenames;
    # shuffling even less so, since we need batches of equal length.
    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=None, shuffle=False)
    return filename_queue
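For context, the two helpers plug together roughly like this (a sketch only; batch_size and examples are placeholder values, and the parsing step depends on how the path examples were serialized, so it is merely indicated):

filenames = _random_filenames('/../tfrecords', batch_size=32, examples=100000)
filename_queue = _get_filename_queue(filenames)

# Read serialized examples off the queue; parsing is schema-dependent.
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# ... tf.parse_single_example(serialized_example, features=...), then
# tf.train.batch(..., batch_size=32) with a single thread keeps the
# equal-length runs together.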

I'm looking for a way to translate this into the new Dataset API.
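(One direction I can imagine, as a sketch only: reuse the pre-sampled filename list and hand it to the Dataset API. This assumes each tfrecord file holds a single serialized path, so that plain sequential batching preserves the equal-length runs that _random_filenames produces:)

filenames = _random_filenames('/../tfrecords', batch_size=32, examples=100000)
dataset = tf.contrib.data.TFRecordDataset(filenames).batch(32)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()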

0 Answers