usinf TF Dataset API大幅降低的原因

时间:2019-02-04 18:51:14

标签: python performance tensorflow dataset

我正在尝试为三重态损失生成批次,其中批次中总是有成对的。下面的代码实现了这一点,但是它非常非常慢。特别是,choose_from_datasets方法似乎是缓慢的原因。

我的代码是否会导致速度变慢?还是有更聪明的方法来做到这一点?

我尝试改用sample_from_datasets,但这无济于事。

def batch_pairs3(dataset, num_classes, shuffle=True, num_classes_per_batch=10, num_images_per_class=2):
    # Isolate each class into its own dataset
    datasets = []
    for cl in range(num_classes):
        this_dataset = dataset.filter(lambda xx, yy: tf.equal(tf.reshape(yy, []), cl))
        if shuffle:
            this_dataset = this_dataset.shuffle(100)
        datasets += [this_dataset]

    # if shuffle:
    #     random.shuffle(datasets)
    selector = tf.contrib.data.Counter().map(
        lambda  x: generator3(x, num_classes, num_classes_per_batch, num_images_per_class))
    selector = selector.apply(tf.contrib.data.unbatch())
    dataset = tf.contrib.data.choose_from_datasets(datasets, selector)

    # Batch
    batch_size = num_classes_per_batch * num_images_per_class
    return dataset.batch(batch_size)

1 个答案:

答案 0 :(得分:0)

tf数据管道不能很好地迭代处理正在处理数据的此类应用程序,除非您可以独立地映射每个数据点来进行此类处理。对于您正在做的事情,最好以tfrecord格式进行数据的预处理和存储,然后使用数据管道以优化的方式读取数据。

请参考以下官方示例,该示例可解决涉及三重态损失的类似问题:Time Contrastive Networksthe data provider