tf.string_input_producer即使将num_epochs设置为大于1,也只运行了一个epoch

时间:2018-10-24 10:25:49

标签: python tensorflow tfrecord

我的输入API使用tf.string_input_producer和tf.parse_single_sequence_example。当我在tf.string_input_producer中将num_epochs设置为大于1的值时,我的队列仍然在1个epoch之后就结束了。

这是预期的行为还是我做错了? 这是相关代码:

class TFRecordReader():



    def __init__(self):
        #some code....

    def execute_queue(self, tensor_queue, exception_message: str, log_dir_path=None):
        """Parse ``tensor_queue`` once, cache the result and return it.

        A summary writer is opened on the log directory and a coordinator /
        thread pair is obtained from ``self._coord_thread()``. Unless the
        coordinator already reports that it should stop, the queue is parsed
        with ``self._parse_tensor`` (only when ``self._data_v`` is still
        ``None``) and the cached value is returned.

        Args:
            tensor_queue: Value handed to ``self._parse_tensor``.
            exception_message: Text printed when ``OutOfRangeError`` is raised
                while parsing.
            log_dir_path: Directory for the summary writer. When ``None``, the
                absolute form of ``'../../audio_log_dir/'`` is used.

        Returns:
            The cached ``self._data_v``, or ``None`` when the coordinator has
            already requested a stop or ``OutOfRangeError`` was caught.
        """
        import os

        if log_dir_path is None:
            path = os.path.abspath('../../audio_log_dir/')
        else:
            path = log_dir_path
        writer = self._summary_file_writer(path)
        coord, thread = self._coord_thread()
        print('should_stop: ', coord.should_stop())

        if not coord.should_stop():
            try:
                if self._data_v is None:
                    self._data_v = self._parse_tensor(tensor_queue)

                # Fixed: this line used to be indented one column too far,
                # which is an IndentationError at import time.
                return self._data_v
            except self.tf.errors.OutOfRangeError:
                print(exception_message)
            finally:
                # NOTE(review): this runs before the caller has evaluated the
                # returned value, so the coordinator is asked to stop and its
                # thread(s) are joined ahead of any later session.run().
                # Presumably this ends the input queue early regardless of
                # num_epochs -- confirm against _coord_thread().
                coord.request_stop()
                coord.join(thread)
                writer.close()

        # Reached when the coordinator had already stopped or the queue was
        # exhausted while parsing.
        return None


    def single_sequence_batch(self,
                              tf_record_path,
                              feature_map,
                              parse_function,
                              num_epochs=None,
                              tf_record_compression=None,
                              queue_completion_message='Data Exhausted!',
                              log_dir_path=None
                              ):
        """Build a single-sequence batch and pass it through ``execute_queue``.

        Stores ``feature_map`` and ``parse_function`` on the instance, asks
        ``self._single_sequence_batch`` for the batch, then returns whatever
        ``self.execute_queue`` returns for it.

        Args:
            tf_record_path: Forwarded to ``self._single_sequence_batch``.
            feature_map: Saved as ``self.feature_map``.
            parse_function: Saved as ``self.parse_func``.
            num_epochs: Forwarded to ``self._single_sequence_batch``.
            tf_record_compression: Forwarded to ``self._single_sequence_batch``.
            queue_completion_message: Message given to ``execute_queue``.
            log_dir_path: Log directory given to ``execute_queue``.

        Returns:
            The result of ``self.execute_queue`` for the built batch.
        """
        # Stored on the instance before the batch is built.
        self.feature_map, self.parse_func = feature_map, parse_function

        pending_batch = self._single_sequence_batch(
            tf_record_path=tf_record_path,
            num_epochs=num_epochs,
            tf_record_compression=tf_record_compression,
        )
        return self.execute_queue(
            pending_batch,
            queue_completion_message,
            log_dir_path=log_dir_path,
        )

def _test_single_sequence_batch(num_epochs=1):
    """Drain the reader's queue and report how many items were read.

    Builds a ``TFRecordReader`` over a fixed TFRecord path, then evaluates the
    returned value repeatedly until ``tf.errors.OutOfRangeError`` is raised,
    printing each item and finally the total count.

    Args:
        num_epochs: Passed to ``single_sequence_batch``. NOTE(review): this
            presumably must be a finite positive integer -- with ``None`` the
            input may never be exhausted and this loop would not terminate.
    """
    tfr_path = r'C:/audio_tfrecord/audioapi.tfrecord'
    reader = TFRecordReader()
    data_q = reader.single_sequence_batch(tf_record_path=tfr_path,
                                          feature_map=feature_mapping,
                                          parse_function=parse_func,
                                          tf_record_compression=True,
                                          num_epochs=num_epochs)
    print(len(data_q))
    c = 0
    try:
        # Fixed: the loop used to be `for i in range(num_epochs)`, which ran
        # session.run() only num_epochs times. The loop stopped after that
        # many items, never hit OutOfRangeError and never printed the totals.
        # Read until the input signals exhaustion instead.
        while True:
            val = reader.session.run(data_q)
            print(val)
            c += 1
    except tf.errors.OutOfRangeError:
        print("Total Examples :", c)
        print('Finished!')

0 个答案:

没有答案