假设我有大量的对象,例如,每个对象可以是一个numpy数组的列表。
将此数据集传递到张量流的最佳方法是什么?
我希望能够随机整理数据并形成批处理。可能值得使用标准的python(numpy)过程重新整理数据集并形成批次,然后再使用类似tf.data.Dataset.from_generator()
的方法吗?
由于tf.Tensor
协议缓冲区的大小限制(根据Tensorflow文档),直接将完整数据集转换为tf.GraphDef
的方法似乎没有用。
答案 0 :(得分:0)
您的数据看起来很大,但仍然足够小以适合内存吗?如果是这样,那么您在tf.data.Dataset.from_generator()上的位置正确。然后,您可以像这样
import itertools
# your data
data = range(1024)
def gen():
for item in data:
yield data
ds = Dataset.from_generator(
gen, tf.int64, tf.TensorShape([])).shuffle(buffer_size=128).batch(batch_size=4)
value = ds.make_one_shot_iterator().get_next()
sess.run(value) # array([0, 1, 2, 3])
或者,您可以将数据转储到TFRecord文件并使用TFRecordDataset从中读取。 test应该可以帮助您入门。