How to concatenate augmented data onto a tf.data dataset

Time: 2019-05-18 15:22:27

Tags: python tensorflow concatenation generator

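I load a data stream from a generator with tf.data.Dataset.from_generator, then augment the data and concatenate the original and augmented data to enlarge the dataset. The problem is that the generator is invoked again for every concatenate call. How can I avoid this?

I wrote the following code: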

import tensorflow as tf

tf.enable_eager_execution()


def generator():
    # Some heavy function that loads data points from a stream
    print("CALL generator")
    for i in range(5):  # simulates the data stream
        print("generator iteration: ", str(i))
        yield i


if __name__ == '__main__':
    data = tf.data.Dataset.from_generator(generator, tf.int32)
    dataset = data
    # Augment the data points and append the augmented versions so the
    # dataset holds both the original and the augmented data points
    data = data.concatenate(
        dataset.map(lambda x: x * 2))
    iterator = data.batch(10).prefetch(10).make_one_shot_iterator()
    for im in iterator:
        print(im)

I expected:

CALL generator
generator iteration:  0
generator iteration:  1
generator iteration:  2
generator iteration:  3
generator iteration:  4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)

But I got:

CALL generator
generator iteration:  0
generator iteration:  1
generator iteration:  2
generator iteration:  3
generator iteration:  4
CALL generator
generator iteration:  0
generator iteration:  1
generator iteration:  2
generator iteration:  3
generator iteration:  4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)
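
The second "CALL generator" appears because concatenate opens a separate iterator over the source dataset for each of its two inputs, and every iterator over a from_generator dataset calls the generator function again. One possible way around this, sketched here and not verified against the TensorFlow 1.x setup above, is to call cache() on the source dataset before branching. The sketch assumes TensorFlow 2.4 or later (for the output_signature argument) and a stream small enough to hold in memory. The calls counter and the combined and batches names are illustrative and do not appear in the original code.

```python
import tensorflow as tf

calls = {"generator": 0}  # illustrative counter: how often the generator function runs


def generator():
    # Stand-in for the expensive stream loader from the question.
    calls["generator"] += 1
    for i in range(5):
        yield i


# cache() keeps the elements in memory during the first complete pass, so the
# second input of concatenate() should read from the cache instead of
# invoking the generator again.
data = tf.data.Dataset.from_generator(
    generator,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32),
).cache()

# Original data points followed by their augmented versions.
combined = data.concatenate(data.map(lambda x: x * 2))

batches = [batch.numpy().tolist() for batch in combined.batch(10)]
print(batches)
print("generator calls:", calls["generator"])
```

The trade-off is memory: cache() with no argument keeps every element of the stream in RAM. Passing a filename to cache() stores the elements on disk instead, which may suit larger streams.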

0 answers:

No answers