tensorflow.parse_single_sequence_example stops after a single epoch

Time: 2018-11-01 07:59:23

Tags: python-3.x tensorflow machine-learning tfrecord

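I'm reading my serialized TFRecord files with reader = TFRecordReader() and a _filename_queue method that passes num_epochs > 1, then parsing the records with tf.parse_single_sequence_example. However, my batching stops after only one epoch. The code doesn't raise any error; it processes the first epoch's batches just fine. I have a class TFRecord that reads the files with the following _read method: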

    import tensorflow as tf  # TensorFlow 1.x queue-based input API


    class TFRecord(object):

        def __init__(self):
            # Assumed: the rest of the class (not shown) keeps a handle to the
            # tf module and caches the list of TFRecord paths.
            self.tf = tf
            self._tf_path = None

        def _read(self, tf_record_path, num_epochs=None, tf_record_compression=None):
            """
            :param tf_record_path: tf_record path
            :param num_epochs: number of epochs
            :param tf_record_compression: compression type or bool
            :return: serialized record tensor
            """
            ctype = self.tf.python_io.TFRecordCompressionType
            # Check the bools by identity first: True == 1 == ctype.ZLIB, so a plain
            # `in` test would silently map ctype.ZLIB to GZIP.
            if tf_record_compression is True:
                tf_record_compression = ctype.GZIP
            elif tf_record_compression is False:
                tf_record_compression = None
            assert tf_record_compression in (None, 'GZIP', 'ZLIB', ctype.NONE, ctype.ZLIB, ctype.GZIP)

            if tf_record_compression in (None, ctype.NONE):
                options = None
            elif tf_record_compression in ('GZIP', ctype.GZIP):
                options = self.tf.python_io.TFRecordOptions(ctype.GZIP)
            else:
                options = self.tf.python_io.TFRecordOptions(ctype.ZLIB)

            with self.tf.variable_scope('Read'):
                # Use self.tf consistently (the original mixed tf and self.tf)
                reader = self.tf.TFRecordReader(options=options)
                filename_queue = self._filename_queue(tf_path=tf_record_path, num_epochs=num_epochs)
                # Each read() returns one serialized record, dequeuing a new filename when needed
                _, serialized_output = reader.read(filename_queue)
                print("TFRecord Data Serialized!")
            return serialized_output

        def _filename_queue(self, tf_path, num_epochs=None):
            """
            :param tf_path: a TFRecord path or a list of paths
            :param num_epochs: number of passes over the files (defaults to 1)
            :return: filename queue
            """
            # The path list is cached on the first call only
            if self._tf_path is None:
                if isinstance(tf_path, str):
                    self._tf_path = [tf_path]
                else:
                    self._tf_path = tf_path

            # An unset epoch count falls back to a single epoch
            if num_epochs is None:
                num_epochs = 1
            print('file_name path : ', self._tf_path)
            print('The Number of Epochs is : ', num_epochs)
            # With num_epochs set, this creates a local epoch counter; run
            # tf.local_variables_initializer() before starting the queue runners.
            file_name_queue = self.tf.train.string_input_producer(
                self._tf_path, num_epochs=num_epochs, name='FileNameQueue')
            print(type(file_name_queue).__name__)
            return file_name_queue
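
The epoch handling in _filename_queue can be modeled without TensorFlow. The sketch below is a simplified, pure-Python stand-in, not TensorFlow code: the names resolve_num_epochs, normalize_paths, FilenameQueueModel, OutOfRange and drain are hypothetical. It mirrors two things: the code above, which maps num_epochs=None to 1, and the documented behavior of string_input_producer, which emits each filename num_epochs times and then signals end of input (in TensorFlow, reader.read raises tf.errors.OutOfRangeError once the filename queue is closed and empty). Shuffling is ignored.

```python
from typing import List, Optional, Sequence, Union


class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def resolve_num_epochs(num_epochs: Optional[int]) -> int:
    # Mirrors _filename_queue: an unset epoch count silently becomes 1.
    return 1 if num_epochs is None else num_epochs


def normalize_paths(tf_path: Union[str, Sequence[str]]) -> List[str]:
    # Mirrors _filename_queue: a single path string is wrapped in a list.
    return [tf_path] if isinstance(tf_path, str) else list(tf_path)


class FilenameQueueModel:
    """Emits every path once per epoch, like string_input_producer(shuffle=False)."""

    def __init__(self, tf_path, num_epochs: Optional[int] = None):
        paths = normalize_paths(tf_path)
        epochs = resolve_num_epochs(num_epochs)
        self._items = iter([p for _ in range(epochs) for p in paths])

    def dequeue(self) -> str:
        try:
            return next(self._items)
        except StopIteration:
            # The real queue is closed after the last epoch.
            raise OutOfRange("filename queue is closed and empty") from None


def drain(queue: FilenameQueueModel) -> List[str]:
    # Consumer loop: dequeue until end of input, the way a TF1 loop
    # stops when OutOfRangeError is raised.
    seen = []
    while True:
        try:
            seen.append(queue.dequeue())
        except OutOfRange:
            break
    return seen


if __name__ == "__main__":
    # "train.tfrecord" is a hypothetical name; no file is opened.
    print(drain(FilenameQueueModel("train.tfrecord")))
    print(drain(FilenameQueueModel("train.tfrecord", num_epochs=3)))
```

In the real pipeline each emitted filename corresponds to one full pass of reader.read over that file, so the number of times a path leaves the queue is the number of epochs the reader sees.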

I still don't know why my batching stops after one epoch!

0 answers