将卫星图像转换为numpy数组时的内存错误

时间:2017-11-15 17:32:10

标签: python arrays image memory bigdata

我有一幅大小约 1GB 的图像,想把它用作深度学习的训练数据。但是在把这幅大图像转换为 numpy 数组时出现了内存错误。我是 Python 新手,请帮我解决这个问题。(运行环境:16GB 内存的笔记本电脑)

代码如下:

    import numpy as np
    import os,glob,datetime,random,time
    from joblib import Parallel,delayed
    import matplotlib.pyplot as plt
    plt.switch_backend('agg')
    from sklearn.cross_validation import train_test_split
    ####################################################
    def get_img_from_win(img,win,img_row,img_col):
        """Cut one sub-image out of ``img`` for every window in ``win``.

        Each window is indexed as
        ``[[col_start, col_stop], [row_start, row_stop]]``.

        * 3-D ``img``: the crop keeps all channels and is passed through
          ``img_extend(crop, img_row, img_col)`` before being collected.
        * 2-D ``img``: the crop is collected unchanged (a slice of ``img``).

        Args:
            img: 2-D or 3-D NumPy array.
            win: sequence of windows as described above.
            img_row: forwarded to ``img_extend`` (3-D input only).
            img_col: forwarded to ``img_extend`` (3-D input only).

        Returns:
            A list with one array per window, in window order.

        Raises:
            ValueError: if ``img`` is neither 2-D nor 3-D.  Previously this
                case fell through both branches and failed with an
                ``UnboundLocalError`` on the ``return``.
        """
        if img.ndim not in (2, 3):
            raise ValueError(
                "img must be a 2-D or 3-D array, got %d dimension(s)" % img.ndim
            )

        is_multichannel = img.ndim == 3
        imgs = []
        for win_c in win:
            cols = slice(win_c[0][0], win_c[0][1])
            rows = slice(win_c[1][0], win_c[1][1])
            if is_multichannel:
                imgs.append(img_extend(img[rows, cols, :], img_row, img_col))
            else:
                imgs.append(img[rows, cols])
        return imgs

    def get_train_win(img):
        """Return the training windows for ``img`` and how many there are.

        A single window spanning the whole image is produced, laid out as
        ``[[0, width], [0, height]]`` where ``width`` is ``img.shape[1]`` and
        ``height`` is ``img.shape[0]``.

        Returns:
            ``(windows, count)``: the list of window arrays and its length.
        """
        height, width = img.shape[0], img.shape[1]
        full_image_window = np.array([[0, width], [0, height]])
        windows = [full_image_window]
        return windows, len(windows)

    def draw_win_in_ori(img,win,size):
        """Draw each window's outline on a copy of ``img``, show it and save it.

        For every window ``[[col_start, col_stop], [row_start, row_stop]]`` a
        band of ``size`` pixels along each of the four edges is set to 0 on a
        copy of ``img`` (the input array itself is not modified).  The result
        is passed to ``plt.imshow`` and written with ``plt.imsave``.

        Args:
            img: 3-D image array (indexed with three subscripts below).
            win: sequence of windows as described above.
            size: thickness of the outline in pixels.

        Fixes relative to the pasted original: the ``imsave`` call is back
        inside the function body, the typographic quotes around ``_`` are
        plain quotes, the script-name global uses the spelling the file
        actually defines (``Script_name``), and the file name ends in
        ``.PNG`` so that it carries a real extension.
        """
        img_n = np.copy(img)
        for win_c in win:
            x0, x1 = win_c[0][0], win_c[0][1]
            y0, y1 = win_c[1][0], win_c[1][1]
            img_n[y0:y0 + size, x0:x1, :] = 0   # top edge
            img_n[y1 - size:y1, x0:x1, :] = 0   # bottom edge
            img_n[y0:y1, x0:x0 + size, :] = 0   # left edge
            img_n[y0:y1, x1 - size:x1, :] = 0   # right edge

        plt.imshow(img_n)

        # NOTE(review): `result_file` (the output directory) is not defined in
        # the visible part of the script -- it must exist as a module-level
        # name before this function is called.
        out_name = 'win_img _' + Script_name + '_' + file_time_global + '.PNG'
        plt.imsave(os.path.join(result_file, out_name), img_n)

    def img_extend(img,img_rows,img_cols):
        """Pad ``img`` on all four sides by repeating its border pixels.

        The pad width is ``img_cols // 2`` on the top, bottom, left and right;
        corners are filled with the nearest corner pixel.  Only the first two
        axes are padded, so any number of channels is supported, and the
        result has the same dtype as ``img``.  For ``uint8`` three-channel
        input the output is identical to the previous hand-assembled version,
        without its intermediate buffers.

        Args:
            img: image array with rows on axis 0 and columns on axis 1.
            img_rows: unused; kept so existing callers keep working.
            img_cols: patch width; half of it (rounded down) is the pad width.

        Returns:
            A new array of shape
            ``(img.shape[0] + 2*half, img.shape[1] + 2*half, ...)``.
        """
        half = int(img_cols // 2)
        # Pad rows and columns only; trailing axes (channels) are left alone.
        pad_width = [(half, half), (half, half)] + [(0, 0)] * (img.ndim - 2)
        return np.pad(img, pad_width, mode='edge')


    def get_data_from_img(img,img_bin,img_rows,img_cols,Samp_Max,case):
        """Collect training patches from ``img`` and their labels from ``img_bin``.

        With ``half = img_cols // 2``, every position ``(r, c)`` with
        ``0 <= r < img.shape[0] - 2*half`` and ``0 <= c < img.shape[1] - 2*half``
        yields the patch ``img[r:r + 2*half, c:c + 2*half, :]`` labelled with
        ``img_bin[r, c]``.  Positions whose label equals 255 are positives;
        all others are negatives.  If there are more than ``Samp_Max``
        negatives, ``Samp_Max`` of them are drawn with ``random.sample``.

        The positions are found with vectorized NumPy operations instead of a
        Python loop that created one array view per pixel; only the selected
        positions are turned into patch objects.  Order (positives in
        row-major order, then negatives) and the ``random.sample`` call are
        the same as before.

        Args:
            img: 3-D image array the patches are sliced from.
            img_bin: 2-D label array covering at least the sampled positions.
            img_rows: unused; kept so existing callers keep working.
            img_cols: patch width; also determines ``half``.
            Samp_Max: maximum number of negative samples to keep.
            case: must be ``'train'``.

        Returns:
            ``(data, label)``: a list of patches (views into ``img``, not
            copies) and the list of matching label values.

        Raises:
            ValueError: if ``case`` is not ``'train'`` (previously an
                ``UnboundLocalError``), or if ``img_bin`` is not a 2-D array
                large enough to cover every sampled position.
        """
        if case != 'train':
            raise ValueError(
                "unsupported case %r: only 'train' is implemented" % (case,)
            )

        half = int(img_cols // 2)
        n_rows = max(img.shape[0] - 2 * half, 0)
        n_cols = max(img.shape[1] - 2 * half, 0)
        if n_rows == 0 or n_cols == 0:
            # The image is too small to hold a single patch.
            return [], []

        labels_map = np.asarray(img_bin)
        if (labels_map.ndim != 2 or labels_map.shape[0] < n_rows
                or labels_map.shape[1] < n_cols):
            raise ValueError(
                "img_bin must be a 2-D array of at least %d x %d, got shape %r"
                % (n_rows, n_cols, labels_map.shape)
            )
        labels_map = labels_map[:n_rows, :n_cols]

        # Flat row-major indices reproduce the visiting order of the
        # original nested row/column loops.
        is_pos = labels_map == 255
        pos_idx = np.flatnonzero(is_pos)
        neg_idx = np.flatnonzero(~is_pos)
        if len(neg_idx) > Samp_Max:
            keep = random.sample(range(len(neg_idx)), Samp_Max)
            neg_idx = neg_idx[keep]

        patch = 2 * half
        data = []
        label = []
        for flat in np.concatenate((pos_idx, neg_idx)):
            r, c = divmod(int(flat), n_cols)
            data.append(img[r:r + patch, c:c + patch, :])
            label.append(labels_map[r, c])
        return data, label

    def reshape_data(data):
        """Reorder a batch of patches to channels-first and rescale to float32.

        The input is indexed ``[sample, row, col, channel]`` and the output
        ``[sample, channel, row, col]``.  Every value ``v`` becomes
        ``(v - 127) / 128`` as ``float32``, so 0 maps to about -0.992 and 255
        maps to 1.0 (the range is roughly [-1, 1], not [0, 1]).

        The arithmetic is done in place on the float32 copy made by
        ``astype``, which avoids two further full-size temporaries; the
        caller's array is never modified.

        Args:
            data: 4-D array-like of patches.
                  # presumably 8-bit image data -- confirm with the caller

        Returns:
            A new ``float32`` array with the axes reordered and values scaled.
        """
        # (N, rows, cols, channels) -> (N, channels, rows, cols)
        data = np.transpose(data, (0, 3, 1, 2))
        # astype always returns a fresh array here, so the in-place updates
        # below only ever touch our own copy.
        data = data.astype('float32')
        data -= 127
        data /= 128
        return data
    ###########  file locations and global variables  ###########
    # Patch size handed to the window/patch helpers below.
    img_rows = img_cols = 18
    # Not referenced anywhere in the visible part of the script.
    n_procs = 10
    a = os.path.basename(__file__)
    Script_name = a
    data_file = os.path.join('..', 'data')
    now_global = datetime.datetime.now()
    # Minute-resolution timestamp used when naming output files.
    file_time_global = str(now_global.strftime("%Y-%m-%d-%H-%M"))

    ##############################  prepare data  ###########################
    ori_img = plt.imread(os.path.join(data_file, 'ori_train.tif'))
    # NOTE(review): multiplying by 255 presumably assumes the mask file stores
    # 0/1 values -- confirm against the actual data.
    ori_bin_img = plt.imread(os.path.join(data_file, 'bin_train.tif')) * 255
    win, nb_win = get_train_win(ori_img)

    # Writes an image with the window outlines drawn on it.
    draw_win_in_ori(ori_img, win, 3)

    imgs = get_img_from_win(ori_img, win, img_rows, img_cols)
    imgs_bin = get_img_from_win(ori_bin_img, win, img_rows, img_cols)

    ############################ get train data and label ############################
    X_T, y_T = [], []
    for i in range(len(imgs)):
        Train_data, Train_label = get_data_from_img(
            imgs[i], imgs_bin[i], img_rows, img_cols, 50000, 'train')
        X_T += Train_data
        y_T += Train_label

    # One random permutation, applied to samples and labels alike so the
    # pairs stay aligned.
    shuffle = random.sample(range(len(X_T)), len(X_T))
    train_data = reshape_data(np.array([X_T[x] for x in shuffle]))

    # Labels are divided by 255 (255 -> 1.0, 0 -> 0.0).
    labels = np.array([y_T[x] for x in shuffle]) / 255
    X_train, X_valid, y_train, y_valid = train_test_split(
        train_data, labels, test_size=0.2, random_state=42)

0 个答案:

没有答案