如何在使用估算器进行训练期间仅将批处理数据加载到内存中?

时间:2019-06-20 05:55:35

标签: tensorflow load conv-neural-network tensorflow-datasets tensorflow-estimator

目前,我使用以下代码使用估算器训练模型。但是,当我使用大型数据集时,我的内存(RAM)不足以加载大型数据集。因此,有没有一种方法可以在使用估计器进行训练时将批处理数据仅加载到内存中?

here示例显示的是keras。如何使用估算器实现它?

当前,我正在将所有数据加载到内存中,并将其提供给估算器。

classifier = tf.estimator.Estimator(model_fn = convNet,model_dir='/dir')
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": train_data},
        y=train_labels,
        batch_size=32,
        num_epochs=1,
        shuffle=True)
classifier.train(input_fn=train_input_fn, steps=657)

1 个答案:

答案 0 :(得分:0)

u可以使用tf.data.dataset。例如:tf.data.TextLineDataset或tf.data.TFRecordDataset。

然后在input_fn中处理batch_size,epoch,shuffle:

    def input_fn():

        dataset=tf.data.TextLineDataset(file_path)

        dataset=dataset.map(map_fn)
        dataset=dataset.shuffle(shuffle).repeat(epochs).batch(batch_size)
        return dataset