我的输入api使用tf.string_input_producer以及tf.parse_single_sequence_example。当我在tf.string_input_producer中将num_epochs> 1设置为1时,我的队列仍然结束。
这是预期的行为还是我做错了? 这是相关代码:
class TFRecordReader():
def __init__(self):
#some code....
def execute_queue(self, tensor_queue, exception_message: str, log_dir_path=None):
import os
if log_dir_path is None:
path = os.path.abspath('../../audio_log_dir/')
else:
path = log_dir_path
writer = self._summary_file_writer(path)
coord, thread = self._coord_thread()
print('should_stop: ', coord.should_stop())
if not coord.should_stop():
try:
if self._data_v is None:
self._data_v = self._parse_tensor(tensor_queue)
return self._data_v
except self.tf.errors.OutOfRangeError:
print(exception_message)
finally:
coord.request_stop()
coord.join(thread)
writer.close()
def single_sequence_batch(self,
tf_record_path,
feature_map,
parse_function,
num_epochs=None,
tf_record_compression=None,
queue_completion_message='Data Exhausted!',
log_dir_path=None
):
self.feature_map = feature_map
self.parse_func = parse_function
batch = self._single_sequence_batch(tf_record_path=tf_record_path,
num_epochs=num_epochs,
tf_record_compression=tf_record_compression)
data_queue = self.execute_queue(batch, queue_completion_message, log_dir_path=log_dir_path)
return data_queue
def _test_single_sequence_batch(num_epochs=1):
tfr_path = r'C:/audio_tfrecord/audioapi.tfrecord'
reader = TFRecordReader()
data_q = reader.single_sequence_batch(tf_record_path=tfr_path,
feature_map=feature_mapping,
parse_function=parse_func,
tf_record_compression=True,
num_epochs=num_epochs)
print(len(data_q))
c = 0
try:
for i in range(num_epochs):
val = reader.session.run(data_q)
print(val)
c += 1
except tf.errors.OutOfRangeError:
print("Total Examples :", c)
print('Finished!')