通过Tensorflow消耗大数据

时间:2018-10-24 15:51:42

标签: python tensorflow deep-learning bigdata

假设我有大量的对象,例如,每个对象可以是一个numpy数组的列表。

将此数据集传递到张量流的最佳方法是什么?

我希望能够随机整理数据并形成批处理。可能值得使用标准的python(numpy)过程重新整理数据集并形成批次,然后再使用类似tf.data.Dataset.from_generator()的方法吗?

由于tf.Tensor协议缓冲区的大小限制(根据Tensorflow文档),直接将完整数据集转换为tf.GraphDef的方法似乎没有用。

1 个答案:

答案 0 :(得分:0)

您的数据看起来很大,但仍然足够小以适合内存吗?如果是这样,那么您在tf.data.Dataset.from_generator()上的位置正确。然后,您可以像这样

import itertools

# your data
data = range(1024)
def gen():
  for item in data:
    yield data

ds = Dataset.from_generator(
    gen, tf.int64, tf.TensorShape([])).shuffle(buffer_size=128).batch(batch_size=4)
value = ds.make_one_shot_iterator().get_next()

sess.run(value)  # array([0, 1, 2, 3])

或者,您可以将数据转储到TFRecord文件并使用TFRecordDataset从中读取。 test应该可以帮助您入门。