I use tf.data.Dataset.from_generator to load a stream of data points from a generator, then augment the data and concatenate the original and the augmented data to enlarge the dataset. The concatenation causes the generator to be called once per branch, so it runs twice. How can I avoid that?
I wrote the following code:
import tensorflow as tf

tf.enable_eager_execution()


def generator():
    # some heavy fn to load data points from a stream (generator)
    print("CALL generator")
    for i in range(5):  # simulation of a data stream
        print("generator iteration: ", str(i))
        yield i


if __name__ == '__main__':
    data = tf.data.Dataset.from_generator(generator, (tf.int32))
    dataset = data
    # augment the data points and concatenate them with the originals so the
    # dataset contains both the original and the augmented data points
    data = data.concatenate(dataset.map(lambda x: x * 2))

    iterator = data.batch(10).prefetch(10).make_one_shot_iterator()
    for im in iterator:
        print(im)
I expect:
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)
But I got:
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
CALL generator
generator iteration: 0
generator iteration: 1
generator iteration: 2
generator iteration: 3
generator iteration: 4
tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32)
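For reference, here is a minimal sketch of a workaround I am considering, assuming tf.data.Dataset.cache() builds an in-memory cache that both branches of the concatenation share once the first branch has completed a full pass (I have not verified that this is the intended fix, which is why I am asking):

import tensorflow as tf

tf.enable_eager_execution()


def generator():
    print("CALL generator")
    for i in range(5):
        print("generator iteration: ", str(i))
        yield i


if __name__ == '__main__':
    data = tf.data.Dataset.from_generator(generator, (tf.int32))
    # cache the generator output in memory; the assumption is that the mapped
    # branch of the concatenation then reads the cached elements instead of
    # re-invoking the generator
    data = data.cache()
    dataset = data
    data = data.concatenate(dataset.map(lambda x: x * 2))

    iterator = data.batch(10).prefetch(10).make_one_shot_iterator()
    for im in iterator:
        print(im)

If that assumption holds, "CALL generator" should be printed only once and the batch should still be tf.Tensor([0 1 2 3 4 0 2 4 6 8], shape=(10,), dtype=int32).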