使用程序实时数据生成训练tensorflow数据API

时间:2018-11-06 00:05:33

标签: python tensorflow keras tensorflow-datasets

我的问题与this相似,因为我想即时生成一批训练数据。我有一个函数get_random_batch(batch_size, input_path, target_path, **other_kwargs),它返回inputstargets,但这是一个纯香草python / numpy函数,而不是tensorflow。因此它返回numpy数组,而不是张量。 (该函数是一个非常复杂的大型函数,带有一些第三方函式库。将其移植到tensorflow不太可行。由于数据量巨大,因此预处理所有内容也不可行,而且我没有空间来存储所有内容实际上,我什至不训练一个纪元,我只是随机选择数十万次迭代。

几年来,我一直在使用tensorflow的低级训练API:生成一批输入目标对,使用Feed进行正向传递(进入占位符)并获取,计算损失,应用渐变,重复等。

现在,我终于想尝试使用较新的数据API(以便可以使用Keras API来构建模型和“拟合”等),但是我不知道如何进行迁移。我所见过的所有文档都假定数据加载和预处理是图形的一部分,并且数据集的输出已经是张量。

-

更新:好的,这似乎可以用于 tf.data.Dataset.from_generator

compile_kwargs = dict(
    optimizer = tf.train.AdamOptimizer(3e-4),
    loss = 'mse',
    metrics = ['accuracy', 'mse']
)

def prepare_data():
    x,t = get_random_training_pair() # this returns one training pair (each np.float32 ndarrays)
    yield (x, t)

dataset = tf.data.Dataset.from_generator(prepare_data, (tf.float32, tf.float32))
dataset = dataset.batch(128).repeat()
model = build_keras_model()
model.compile(**compile_kwargs)
model.summary()
model.fit(dataset, epochs=10, steps_per_epoch=100, shuffle=False)

0 个答案:

没有答案