我在使用 `reader = TFRecordReader()` 的方法中设置了 `num_epochs > 1`，并通过 `_filename_queue` 读取序列化的 TFRecord 文件，再用 `tf.parse_single_sequence_example` 进行解析，但批处理仅在一个 epoch 之后就结束了。代码没有抛出任何错误，第一轮批处理也顺利完成。
我有一个 `TFRecord` 类，它使用以下 `_read` 方法来读取数据：
def _read(self, tf_record_path, num_epochs=None, tf_record_compression=None):
    """Build the TFRecord reading ops and return the serialized record.

    :param tf_record_path: path (str) or list of paths to TFRecord files;
        forwarded unchanged to ``self._filename_queue``.
    :param num_epochs: number of epochs, forwarded unchanged to
        ``self._filename_queue`` (which decides how ``None`` is treated).
    :param tf_record_compression: compression selector. Accepted values:
        ``None`` / ``False`` / ``TFRecordCompressionType.NONE`` -> no compression;
        ``True`` / ``'GZIP'`` / ``TFRecordCompressionType.GZIP`` -> GZIP;
        ``'ZLIB'`` / ``TFRecordCompressionType.ZLIB`` -> ZLIB.
    :return: the serialized-record value returned by ``reader.read``.
    :raises ValueError: if ``tf_record_compression`` is not one of the above.
    """
    ctype = self.tf.python_io.TFRecordCompressionType
    # Booleans are checked by identity BEFORE any equality test. The
    # TFRecordCompressionType members are plain ints in TF 1.x (NONE=0,
    # ZLIB=1, GZIP=2), and True == 1 / False == 0. So a tuple membership test
    # containing True would wrongly treat ZLIB as "True" (i.e. GZIP).
    if tf_record_compression is True:
        compression = ctype.GZIP
    elif tf_record_compression is False or tf_record_compression is None:
        compression = None
    elif tf_record_compression in ('GZIP', ctype.GZIP):
        compression = ctype.GZIP
    elif tf_record_compression in ('ZLIB', ctype.ZLIB):
        compression = ctype.ZLIB
    elif tf_record_compression == ctype.NONE:
        compression = None
    else:
        # Explicit error instead of `assert`: asserts vanish under `python -O`,
        # which previously let unknown values fall through to ZLIB.
        raise ValueError(
            'Unsupported tf_record_compression: {!r}'.format(tf_record_compression))
    options = None if compression is None else self.tf.python_io.TFRecordOptions(compression)

    with self.tf.variable_scope('Read'):
        # Use self.tf consistently; the bare module name `tf` may not be bound here.
        reader = self.tf.TFRecordReader(options=options)
        filename_queue = self._filename_queue(tf_path=tf_record_path, num_epochs=num_epochs)
        _, serialized_output = reader.read(filename_queue)
    print("TFRecord Data Serialized!")
    return serialized_output
def _filename_queue(self, tf_path, num_epochs=None):
    """Create the filename queue that feeds the TFRecord reader.

    :param tf_path: a single path (str) or a list of paths. It is only used
        while ``self._tf_path`` is still ``None``. After that, the cached
        ``self._tf_path`` is reused and later ``tf_path`` values are ignored.
    :param num_epochs: how many times each file is produced. ``None`` is
        mapped to 1 here. This differs from ``string_input_producer`` itself,
        where ``None`` means "cycle forever". A caller that does not forward
        its epoch count therefore gets exactly one epoch.
    :return: the queue object returned by ``self.tf.train.string_input_producer``.
    """
    # NOTE(review): the cached path takes precedence over tf_path on later
    # calls; confirm this is intended if one object reads several datasets.
    if self._tf_path is None:
        self._tf_path = [tf_path] if isinstance(tf_path, str) else tf_path
    # Historical default: a single epoch when no count is given.
    epochs = 1 if num_epochs is None else num_epochs
    print('file_name path : ', self._tf_path)
    print('The Number of Epochs is : ', epochs)
    # With a finite num_epochs, string_input_producer keeps its epoch counter
    # in a local variable. Run tf.local_variables_initializer() before
    # starting the queue runners. Reading ends with an OutOfRange error once
    # the epochs are exhausted.
    file_name_queue = self.tf.train.string_input_producer(
        self._tf_path, num_epochs=epochs, name='FileNameQueue')
    print(type(file_name_queue).__name__)
    return file_name_queue
我仍然不明白为什么批处理在 1 个 epoch 之后就结束了！