我正在尝试为三重态损失生成批次,其中批次中总是有成对的。下面的代码实现了这一点,但是它非常非常慢。特别是,choose_from_datasets方法似乎是缓慢的原因。
我的代码是否会导致速度变慢?还是有更聪明的方法来做到这一点?
我尝试改用sample_from_datasets,但这无济于事。
def batch_pairs3(dataset, num_classes, shuffle=True, num_classes_per_batch=10, num_images_per_class=2):
# Isolate each class into its own dataset
datasets = []
for cl in range(num_classes):
this_dataset = dataset.filter(lambda xx, yy: tf.equal(tf.reshape(yy, []), cl))
if shuffle:
this_dataset = this_dataset.shuffle(100)
datasets += [this_dataset]
# if shuffle:
# random.shuffle(datasets)
selector = tf.contrib.data.Counter().map(
lambda x: generator3(x, num_classes, num_classes_per_batch, num_images_per_class))
selector = selector.apply(tf.contrib.data.unbatch())
dataset = tf.contrib.data.choose_from_datasets(datasets, selector)
# Batch
batch_size = num_classes_per_batch * num_images_per_class
return dataset.batch(batch_size)
答案 0 :(得分:0)
tf数据管道不能很好地迭代处理正在处理数据的此类应用程序,除非您可以独立地映射每个数据点来进行此类处理。对于您正在做的事情,最好以tfrecord格式进行数据的预处理和存储,然后使用数据管道以优化的方式读取数据。
请参考以下官方示例,该示例可解决涉及三重态损失的类似问题:Time Contrastive Networks,the data provider