Optimizing memory usage with plt.imread()

Asked: 2018-10-28 12:56:06

Tags: python-3.x image numpy matplotlib

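I am working with a large number of images (120k). Each file is one colour component (RGB + yellow) of a single image, so I have 30k unique images, each split into 4 files: one each for red, green, blue and yellow.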

For each image ID, I merge the 4 components (RGB + yellow) into one (M, N, 4) array, where M and N are the image dimensions.

I am using the following code:

import pickle  # needed for the "SAVE OUTPUT" step at the end

import pandas as pd
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from os import listdir

train_labels_data = pd.read_csv('/Documents/train.csv')

def merge_rgb(img_id, colours=('red', 'blue', 'green', 'yellow'), path='train'):
    """
    For each image, return an array of shape (M, N, 4) whose last axis
    holds the red, blue, green and yellow channels, in that order.
    """
    merged_colour_img = []

    for colour in colours:
        full_path = path + '/' + img_id + '_' + colour + '.png'
        colour_img=mpimg.imread(full_path)
        merged_colour_img.append(colour_img)    

    merged_colour_img = np.dstack((merged_colour_img))
    return merged_colour_img


def train_data_label(train_labels_data):
    """
    From the train_labels csv file, create a list of labels, and create a
    large array for the train data in the same order.
    """
    train_ids = [img_id for img_id in train_labels_data['Id']]
    train_labels = [label for label in train_labels_data['Target']]

    print ('Labels and Ids collected')

    train_data = []

    i=0
    for img_id in train_ids:

        print ('Merging Image')
        train_data_img = merge_rgb (img_id)
        print ('Merging done, appending the (M,N,4) array to a list')
        train_data.append(train_data_img)
        i += 1
        print ('Done appending, going to next image')
        print(i)

    print('Stacking all images in one big array')
    train_data = np.stack(train_data)

    return train_labels, train_data


train_labels, train_data = train_data_label(train_labels_data)


# SAVE OUTPUT

# pickle.dump writes to the file directly, so the serialized bytes are not
# first built as a separate in-memory object the way pickle.dumps does.
with open("/Documents/train_data.pkl", "wb") as data:
    pickle.dump(train_data, data)

with open("/Documents/train_data_labels.pkl", "wb") as data:
    pickle.dump(train_labels, data)

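However, this code uses a lot of memory and crashes partway through, before all the images have been processed. Since the memory goes into loading and combining the images, I suspect the merge_rgb function could be improved. Any suggestions?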
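
One possible direction, sketched under assumptions and not tested on this dataset: `mpimg.imread` returns PNG data as float32 in the range 0 to 1, so each channel costs 4 bytes per pixel, and collecting per-image arrays in a list before calling `np.stack` means the list and the stacked copy exist in memory at the same time. The sketch below preallocates a single uint8 array and writes each channel straight into it. The names `estimate_bytes`, `to_uint8`, `build_train_data` and the `load_channel` callable are hypothetical, and the conversion to uint8 assumes the source files are 8-bit PNGs.

```python
import numpy as np


def estimate_bytes(n_images, height, width, n_channels=4, dtype=np.float32):
    """Bytes needed for one (n_images, height, width, n_channels) array."""
    return n_images * height * width * n_channels * np.dtype(dtype).itemsize


def to_uint8(channel):
    """Scale a float image in [0, 1] to uint8 in [0, 255].

    Integer inputs are cast as-is (assumes their values already fit in 0..255).
    """
    if np.issubdtype(channel.dtype, np.floating):
        return np.round(channel * 255).astype(np.uint8)
    return channel.astype(np.uint8)


def build_train_data(img_ids, load_channel, height, width,
                     colours=('red', 'blue', 'green', 'yellow'), out=None):
    """Fill one preallocated array instead of growing a list and stacking it.

    load_channel(img_id, colour) is a hypothetical callable that returns one
    (height, width) array. `out` may be any writable array of the right
    shape, for example a disk-backed memmap.
    """
    shape = (len(img_ids), height, width, len(colours))
    if out is None:
        out = np.empty(shape, dtype=np.uint8)
    for i, img_id in enumerate(img_ids):
        for c, colour in enumerate(colours):
            # Each channel is converted and written in place, so only one
            # float32 channel is alive at any time.
            out[i, :, :, c] = to_uint8(load_channel(img_id, colour))
    return out
```

With the file layout from the question, `load_channel` could be `lambda img_id, colour: mpimg.imread('train/' + img_id + '_' + colour + '.png')`. As an illustration only, since M and N are not stated here, `estimate_bytes(30000, 512, 512)` comes to roughly 126 GB as float32 versus roughly 31 GB with `dtype=np.uint8`. If that still exceeds available RAM, an array created with `np.lib.format.open_memmap` can be passed as `out` so the data lives on disk.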

Thanks

0 answers:

No answers yet