从Tensorflow中的.tfrecords文件中获取记录总数

时间:2016-11-07 18:33:46

标签: tensorflow tfrecord

是否可以从.tfrecords文件中获取记录总数?与此相关,人们如何通常跟踪训练模型时已经过的时期数?虽然我们可以指定batch_sizenum_of_epochs,但我不确定是否可以直接获取current epoch等值,每个时期的批次数等等 - 只是为了这样我可以更好地控制培训的进展情况。目前,我只是使用脏黑客来计算这个,因为我事先知道我的.tfrecords文件中有多少条记录以及我的小型机的大小。感谢任何帮助..

5 个答案:

答案 0 :(得分:28)

要计算记录数,您应该可以使用tf.python_io.tf_record_iterator

c = 0
for fn in tf_records_filenames:
  for record in tf.python_io.tf_record_iterator(fn):
     c += 1

为了跟踪模型训练,tensorboard派上用场。

答案 1 :(得分:17)

不,不可能。 TFRecord不存储有关存储在其中的数据的任何元数据。这个档案

  

表示一系列(二进制)字符串。格式不是随机的   访问,因此它适用于流式传输大量数据但不适用   适用于需要快速分片或其他非顺序访问的情况。

如果需要,您可以手动存储此元数据或使用record_iterator获取数字(您需要遍历所有记录:

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))

如果你想知道当前的纪元,你可以从张量板或从循环中打印数字来做到这一点。

答案 2 :(得分:1)

随着tf.io.tf_record_iterator被弃用,萨尔瓦多·达利(Salvador Dali)伟大的answer现在应该阅读

tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))

答案 3 :(得分:0)

根据tf_record_iterator的弃用警告,我们还可以使用急切执行来计数记录。

#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# This is where we get the count of records
records_n = sum(1 for record in tf.data.TFRecordDataset(tf.gfile.Glob(input_pattern)))

print("records_n = {}".format(records_n))

答案 4 :(得分:0)

由于 tf.enable_eager_execution() 不再有效,请使用:

tf.compat.v1.enable_eager_execution

sum(1 for _ in tf.data.TFRecordDataset(FILENAMES))