我正在尝试重写 SSD (单发多盒检测器)Keras的库,以便在批处理构造期间读取 tfrecord 文件(在训练阶段)。
我创建一个 tfrecord ,其中每个元素都包含:
在进行keras训练期间,批量创建的方法如下:
def generate(self,
batch_size=32,
shuffle=True,
transformations=[],
label_encoder=None,
returns={'processed_images', 'encoded_labels'},
keep_images_without_gt=False,
degenerate_box_handling='remove'):
...
with tf.Session() as session:
while True:
batch_X, batch_y = [], []
if current >= self.dataset_size:
current = 0
...
tf_batch = self.tfrecord_dataset.shuffle(batch_size, reshuffle_each_iteration = True)
self.rewrite_fields(tf_batch, session, batch_size)
...
# I wrapped the self's fields in various variable
# (I also apply some image transformation)
ret = []
if 'processed_images' in returns: ret.append(batch_X)
if 'encoded_labels' in returns: ret.append(batch_y_encoded)
if 'matched_anchors' in returns: ret.append(batch_matched_anchors)
if 'processed_labels' in returns: ret.append(batch_y)
if 'filenames' in returns: ret.append(batch_filenames)
if 'image_ids' in returns: ret.append(batch_image_ids)
if 'evaluation-neutral' in returns: ret.append(batch_eval_neutral)
if 'inverse_transform' in returns: ret.append(batch_inverse_transforms)
if 'original_images' in returns: ret.append(batch_original_images)
if 'original_labels' in returns: ret.append(batch_original_labels)
# K.clear_session()
# session.graph.
# tf.initialize_all_variables()
if not (tf_batch is None):
del tf_batch
del self.images
self.images = None
self.labels = []
self.eval_neutral = []
self.image_ids = []
self.dataset_indices = []
yield ret
提取数据的方法rewrite_fields(tf_dataset_batch, session, batch_size)
是:
def rewrite_fields(self,
tf_dataset_batch,
session,
batch_size):
self.images = []
self.labels = []
self.image_ids = []
self.eval_neutral = []
# Iterate over every image
# tf_dataset_batch.map(self.map_tfrecord_feattures)
iterator = tf_dataset_batch.make_one_shot_iterator()
next_record = iterator.get_next()
# NOTE: If you use "tf.Dataset.batch()" you must take the first element of every
# field (es: For take the real image data you must perform "tf.decode_raw(image[0], tf.uint8)")
# AND NOT "tf.decode_raw(image, tf.uint8)"
# Iterate with a tensorflow-session
for index in range(batch_size):
image, labels, image_shape, labels_shape, image_id, eval_neutral = session.run(next_record)
# image, labels, image_shape, labels_shape, image_id, eval_neutral = next_record
# image, labels, image_shape, labels_shape, image_id, eval_neutral = self.map_tfrecord_feattures(next_record[0])
# Decode the fields
image_shape = tf.decode_raw(image_shape, tf.int32)
image_shape = image_shape.eval()
# image_shape_array = np.zeros(1)
# tf_image_shape_array = tf.placeholder(dtype=image_shape_array.dtype)
# image_shape = image_shape.eval(feed_dict={tf_image_shape_array: image_shape_array})
image = tf.decode_raw(image, tf.uint8)
image = image.eval()
# image_array = np.zeros(1)
# tf_image_array = tf.placeholder(dtype=image_array.dtype)
# image = image.eval(feed_dict={tf_image_array: image_array})
image = image.reshape(image_shape)
self.images.append(image)
labels_shape = tf.decode_raw(labels_shape, tf.int32)
labels_shape = labels_shape.eval()
# labels_shape_array = np.zeros(1)
# tf_labels_shape_array = tf.placeholder(dtype=labels_shape_array.dtype)
# labels_shape = labels_shape.eval(feed_dict={tf_labels_shape_array: labels_shape_array})
label = tf.decode_raw(labels, tf.int32)
label = label.eval()
# label_array = np.zeros(1)
# tf_label_array = tf.placeholder(dtype=label_array.dtype)
# label = label.eval(feed_dict={tf_label_array: label_array})
label = label.reshape(labels_shape)
eval_neutral = tf.decode_raw(eval_neutral, tf.uint8)
eval_neutral = eval_neutral.eval()
eval_neutral = eval_neutral.astype(bool)
# Add the elements
self.labels.append(label)
self.image_ids.append(image_id)
self.eval_neutral.append(eval_neutral)
self.dataset_indices = np.arange(self.dataset_size, dtype=np.int32)
通过launch model.fit_generator()
的上述生成器,我发现内存在每个时期增加了 2GB 。
我知道,对于每条读取和转换指令(例如将张量转换为numpy数组的image.eval()
),某些操作和数据都会分配到 tensorflow-default-graph
在注释中,有一些尝试使用feed_dict
来避免内存分配(但是它不起作用...)
如何避免疯狂分配内存?
注意:数据集的地图函数为:
@staticmethod
def map_tfrecord_feattures(data_record):
# Define all the record-features
features = {
'image': tf.FixedLenFeature([], tf.string),
'labels': tf.FixedLenFeature([], tf.string),
'image_shape': tf.FixedLenFeature([], tf.string),
'labels_shape': tf.FixedLenFeature([], tf.string),
'image_id': tf.FixedLenFeature([], tf.string),
'eval_neutral': tf.FixedLenFeature([], tf.string)
}
# Extract and cast every feature
sample = tf.parse_single_example(data_record, features)
image = sample['image']
labels = sample['labels']
image_shape = sample['image_shape']
labels_shape = sample['labels_shape']
image_id = sample['image_id']
eval_neutral = sample['eval_neutral']
# Return the values
return image, labels, image_shape, labels_shape, image_id, eval_neutral