如何使用Dataset API一次性读取TFRecord文件中的所有Example记录?

时间:2017-12-20 20:33:15

标签: tensorflow tensorflow-datasets

我正在使用Spark作业生成一个$json_array = []; $json_array[] = 4; $json_array[] = $first_array; $json_array[] = $second_array; return json_encode($json_array); 文件,该文件将是我的(word,count)对的词汇文件。

我想使用数据集API一次加载整个文件,因为我的词汇表文件可以在HDFS上,并且可能分成多个物理文件。那就是说,我发现它非常不直观。到目前为止,这是我的代码:

TFRecord

使用巨大的,固定的前期批量大小是我能够想到在def parse(example): parsed = tf.parse_single_example(example, features={ 'token': tf.FixedLenFeature([], dtype=tf.string), 'count': tf.FixedLenFeature([], dtype=tf.int64) }) return parsed['token'], parsed['count'] filenames = tf.gfile.Glob(filenames) dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(parse) dataset = dataset.batch(MAX_VOCAB_FILE) iterator = dataset.make_one_shot_iterator() token, token_count = iterator.get_next() 的张量中一次性填充所有数据的唯一方法。它看起来也很慢。

有更好的方法吗?

1 个答案:

答案 0 :(得分:0)

我必须在某一点做类似的事情。这不直观,但是Importing Data的TF文档中有一个答案。

为了让我的解决方案更易于理解,我假设您的代码包含在名为get_batch的方法中:

def get_batch(filenames, batch_size):
    # your code
    return token, token_count

batch_size参数替换示例中的MAX_VOCAB_FILE。如果你想将所有连续的(token, token_count)对打印到stdout,一次一行,你可以这样做:

with tf.Session() as sess:
    example = get_batch(filenames, 1)
    while True:
        print(sess.run(example))

然后迭代器耗尽,循环以OutOfRangeError退出。发出的每条记录都是一对张量为零的物体。

希望这是足够的细节来做一些有用的事情。