通过tf.data.Dataset将大型numpy数组馈入TensorFlow估算器

时间:2019-02-27 03:58:30

标签: python arrays numpy tensorflow tensorflow-estimator

TensorFlow的tf.data.Dataset documentation on consuming numpy arrays指出,要结合使用numpy数组和Dataset API,该数组必须足够小(总计<2 GB)才能用作张量,否则它们可以通过占位符输入到数据集中。

但是,如果将Dataset与估计量(在不存在占位符的情况下)结合使用,则文档不提供使用无占位符的大型数组的解决方案。

是否还有其他选项可以将占位符值传递到估算器中?或者是以tfrecordcsv格式提供数据的解决方案?

1 个答案:

答案 0 :(得分:0)

您可以在创建数据集对象之前使用np.splitfrom_generator

chunks = list(np.split(array, 1000))

def gen():
    for i in chunks:
        yield i

dataset = tf.data.Dataset.from_generator(gen, tf.float32)
dataset = dataset.shuffle(shuffle_buffer_size)
...

您可以使用随机播放控制数据集的大小。一次仅加载指定数量。