是否可以从.tfrecords
文件中获取记录总数?与此相关,人们如何通常跟踪训练模型时已经过的时期数?虽然我们可以指定batch_size
和num_of_epochs
,但我不确定是否可以直接获取current epoch
等值,每个时期的批次数等等 - 只是为了这样我可以更好地控制培训的进展情况。目前,我只是使用脏黑客来计算这个,因为我事先知道我的.tfrecords文件中有多少条记录以及我的小型机的大小。感谢任何帮助..
答案 0 :(得分:28)
要计算记录数,您应该可以使用tf.python_io.tf_record_iterator
。
c = 0
for fn in tf_records_filenames:
for record in tf.python_io.tf_record_iterator(fn):
c += 1
为了跟踪模型训练,tensorboard派上用场。
答案 1 :(得分:17)
不,不可能。 TFRecord不存储有关存储在其中的数据的任何元数据。这个档案
表示一系列(二进制)字符串。格式不是随机的 访问,因此它适用于流式传输大量数据但不适用 适用于需要快速分片或其他非顺序访问的情况。
如果需要,您可以手动存储此元数据或使用record_iterator获取数字(您需要遍历所有记录:
sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
如果你想知道当前的纪元,你可以从张量板或从循环中打印数字来做到这一点。
答案 2 :(得分:1)
随着tf.io.tf_record_iterator被弃用,萨尔瓦多·达利(Salvador Dali)伟大的answer现在应该阅读
tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))
答案 3 :(得分:0)
根据tf_record_iterator的弃用警告,我们还可以使用急切执行来计数记录。
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import sys
assert len(sys.argv) == 2, \
"USAGE: {} <file_glob>".format(sys.argv[0])
tf.enable_eager_execution()
input_pattern = sys.argv[1]
# This is where we get the count of records
records_n = sum(1 for record in tf.data.TFRecordDataset(tf.gfile.Glob(input_pattern)))
print("records_n = {}".format(records_n))
答案 4 :(得分:0)
由于 tf.enable_eager_execution() 不再有效,请使用:
tf.compat.v1.enable_eager_execution
sum(1 for _ in tf.data.TFRecordDataset(FILENAMES))