读取tfrecord导致内存不足

时间:2018-09-26 10:41:21

标签: python tensorflow keras

我正在尝试重写 SSD 单发多盒检测器)Keras的库,以便在批处理构造期间读取 tfrecord 文件(在训练阶段)。

我创建一个 tfrecord ,其中每个元素都包含:

  • 图片(作为 uint8 的数组)
  • 图片ID / 图片名称(作为 string
  • 边界框信息,其中包含 x_pos y_pos 宽度高度(作为 uint8 的数组。该数组包含每个边界框)
  • 图像形状作为图像的尺寸:宽度高度 num_channels (以3 uint16
  • 边界框形状作为边界框上一个字段中的尺寸和边界框数量: 5 x num_bounding_boxes (作为 uint16 的数组)
  • eval_neutral 一个 bool 数组(一个 uint8 数组)

在进行keras训练期间,批量创建的方法如下:

def generate(self,
             batch_size=32,
             shuffle=True,
             transformations=[],
             label_encoder=None,
             returns={'processed_images', 'encoded_labels'},
             keep_images_without_gt=False,
             degenerate_box_handling='remove'):

        ...

        with tf.Session() as session:

        while True:

            batch_X, batch_y = [], []

            if current >= self.dataset_size:
                current = 0

            ...

            tf_batch = self.tfrecord_dataset.shuffle(batch_size, reshuffle_each_iteration = True)
            self.rewrite_fields(tf_batch, session, batch_size)

            ...

            # I wrapped the self's fields in various variable 
            # (I also apply some image transformation)

            ret = []
            if 'processed_images' in returns: ret.append(batch_X)
            if 'encoded_labels' in returns: ret.append(batch_y_encoded)
            if 'matched_anchors' in returns: ret.append(batch_matched_anchors)
            if 'processed_labels' in returns: ret.append(batch_y)
            if 'filenames' in returns: ret.append(batch_filenames)
            if 'image_ids' in returns: ret.append(batch_image_ids)
            if 'evaluation-neutral' in returns: ret.append(batch_eval_neutral)
            if 'inverse_transform' in returns: ret.append(batch_inverse_transforms)
            if 'original_images' in returns: ret.append(batch_original_images)
            if 'original_labels' in returns: ret.append(batch_original_labels)

            # K.clear_session()
            # session.graph.
            # tf.initialize_all_variables()

            if not (tf_batch is None):
                del tf_batch
                del self.images
                self.images = None
                self.labels = []
                self.eval_neutral = []
                self.image_ids = []
                self.dataset_indices = []

            yield ret

提取数据的方法rewrite_fields(tf_dataset_batch, session, batch_size)是:

def rewrite_fields(self,
                   tf_dataset_batch,
                   session,
                   batch_size):

    self.images = []
    self.labels = []
    self.image_ids = []
    self.eval_neutral = []

    # Iterate over every image
    # tf_dataset_batch.map(self.map_tfrecord_feattures)
    iterator = tf_dataset_batch.make_one_shot_iterator()
    next_record = iterator.get_next()

    # NOTE: If you use "tf.Dataset.batch()" you must take the first element of every
    # field (es: For take the real image data you must perform "tf.decode_raw(image[0], tf.uint8)")
    # AND NOT "tf.decode_raw(image, tf.uint8)"

    # Iterate with a tensorflow-session
    for index in range(batch_size):

        image, labels, image_shape, labels_shape, image_id, eval_neutral = session.run(next_record)
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = next_record
        # image, labels, image_shape, labels_shape, image_id, eval_neutral = self.map_tfrecord_feattures(next_record[0])

        # Decode the fields
        image_shape = tf.decode_raw(image_shape, tf.int32)
        image_shape = image_shape.eval()
        # image_shape_array = np.zeros(1)
        # tf_image_shape_array = tf.placeholder(dtype=image_shape_array.dtype)
        # image_shape = image_shape.eval(feed_dict={tf_image_shape_array: image_shape_array})
        image = tf.decode_raw(image, tf.uint8)
        image = image.eval()
        # image_array = np.zeros(1)
        # tf_image_array = tf.placeholder(dtype=image_array.dtype)
        # image = image.eval(feed_dict={tf_image_array: image_array})
        image = image.reshape(image_shape)

        self.images.append(image)

        labels_shape = tf.decode_raw(labels_shape, tf.int32)
        labels_shape = labels_shape.eval()
        # labels_shape_array = np.zeros(1)
        # tf_labels_shape_array = tf.placeholder(dtype=labels_shape_array.dtype)
        # labels_shape = labels_shape.eval(feed_dict={tf_labels_shape_array: labels_shape_array})
        label = tf.decode_raw(labels, tf.int32)
        label = label.eval()
        # label_array = np.zeros(1)
        # tf_label_array = tf.placeholder(dtype=label_array.dtype)
        # label = label.eval(feed_dict={tf_label_array: label_array})
        label = label.reshape(labels_shape)

        eval_neutral = tf.decode_raw(eval_neutral, tf.uint8)
        eval_neutral = eval_neutral.eval()
        eval_neutral = eval_neutral.astype(bool)

        # Add the elements
        self.labels.append(label)
        self.image_ids.append(image_id)
        self.eval_neutral.append(eval_neutral)

    self.dataset_indices = np.arange(self.dataset_size, dtype=np.int32)

通过launch model.fit_generator()的上述生成器,我发现内存在每个时期增加了 2GB

我知道,对于每条读取和转换指令(例如将张量转换为numpy数组的image.eval()),某些操作和数据都会分配到 tensorflow-default-graph

在注释中,有一些尝试使用feed_dict来避免内存分配(但是它不起作用...)

如何避免疯狂分配内存?

注意:数据集的地图函数为:

@staticmethod
def map_tfrecord_feattures(data_record):

    # Define all the record-features
    features = {
        'image': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.string),
        'image_shape': tf.FixedLenFeature([], tf.string),
        'labels_shape': tf.FixedLenFeature([], tf.string),
        'image_id': tf.FixedLenFeature([], tf.string),
        'eval_neutral': tf.FixedLenFeature([], tf.string)
    }

    # Extract and cast every feature
    sample = tf.parse_single_example(data_record, features)

    image = sample['image']
    labels = sample['labels']
    image_shape = sample['image_shape']
    labels_shape = sample['labels_shape']
    image_id = sample['image_id']
    eval_neutral = sample['eval_neutral']

    # Return the values
    return image, labels, image_shape, labels_shape, image_id, eval_neutral

0 个答案:

没有答案