将2GB以上的数据传递给tf.estimator

时间:2018-12-07 22:02:35

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我有standalone.confx_train的numpy数组,每个数组都大于2GB。我想使用tf.estimator API训练模型,但出现错误:

y_train

我正在使用以下数据传递数据

ValueError: Cannot create a tensor proto whose content is larger than 2GB

tf.data文档mentions this error,并使用带有占位符的传统TenforFlow API提供解决方案。不幸的是,我不知道如何将其转换为tf.estimator API?

1 个答案:

答案 0 :(得分:0)

最适合我的解决方案是使用

tf.estimator.inputs.numpy_input_fn(x_train, y_train, num_epochs=EPOCHS,
                                   batch_size=BATCH_SIZE, shuffle=True)

而不是input_fn。唯一的问题是tf.estimator.inputs.numpy_input_fn会发出弃用警告,因此很遗憾,这也会停止工作。