我有一张大小为1GB的图像，想把它用作深度学习的训练数据，但在把这张大图像转换为NumPy数组时出现了内存错误。我是Python新手，请帮我解决这个问题。（我使用的是一台16GB内存的笔记本电脑。）
代码如下：
import datetime
import glob
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18 only provides the old module
    from sklearn.cross_validation import train_test_split

plt.switch_backend('agg')
####################################################
def get_img_from_win(img, win, img_row, img_col):
    """Cut one sub-image per window out of ``img``.

    Each window is ``[[col_start, col_stop], [row_start, row_stop]]``.
    Crops of a 3-D image are edge-padded with ``img_extend`` (margin derived
    from ``img_col``) so a full patch can later be cut around every original
    pixel; crops of a 2-D image (the binary mask in this script) are returned
    unpadded, as views into ``img``.

    Parameters
    ----------
    img : ndarray, 2-D or 3-D
    win : sequence of 2x2 index arrays
    img_row, img_col : int
        Patch geometry forwarded to ``img_extend`` for 3-D input.

    Returns
    -------
    list of ndarray, one per window, in window order.

    Raises
    ------
    ValueError
        If ``img`` is neither 2-D nor 3-D.
    """
    if img.ndim not in (2, 3):
        raise ValueError('expected a 2-D or 3-D image, got ndim=%d' % img.ndim)
    imgs = []
    for (c0, c1), (r0, r1) in win:
        crop = img[r0:r1, c0:c1]
        if img.ndim == 3:
            crop = img_extend(crop, img_row, img_col)
        imgs.append(crop)
    return imgs
def get_train_win(img):
    """Return a single window covering all of ``img`` and the window count.

    The window is a 2x2 integer array laid out as
    ``[[col_start, col_stop], [row_start, row_stop]]``.
    """
    height, width = img.shape[0], img.shape[1]
    windows = [np.array([[0, width], [0, height]])]
    return windows, len(windows)
def draw_win_in_ori(img, win, size, out_dir='.', script_name=None, file_time=None):
    """Save a copy of ``img`` with every window's border painted black.

    Parameters
    ----------
    img : ndarray
        Image to annotate. It is copied and never modified.
    win : sequence of 2x2 index arrays
        Windows as ``[[col_start, col_stop], [row_start, row_stop]]``.
    size : int
        Border thickness in pixels.
    out_dir : str, optional
        Directory for the PNG (created if missing). Defaults to ``'.'``.
    script_name : str, optional
        Tag used in the file name. Defaults to this module's file name
        without extension.
    file_time : str, optional
        Timestamp used in the file name. Defaults to now, as
        ``YYYY-mm-dd-HH-MM``.

    Returns
    -------
    str
        Path of the written file, ``<out_dir>/win_img_<script_name>_<file_time>.png``.
    """
    img_n = np.copy(img)
    for (c0, c1), (r0, r1) in win:
        # ``...`` keeps any channel axis untouched, so 2-D images work too.
        img_n[r0:r0 + size, c0:c1, ...] = 0  # top edge
        img_n[r1 - size:r1, c0:c1, ...] = 0  # bottom edge
        img_n[r0:r1, c0:c0 + size, ...] = 0  # left edge
        img_n[r0:r1, c1 - size:c1, ...] = 0  # right edge
    if script_name is None:
        script_name = os.path.splitext(os.path.basename(__file__))[0]
    if file_time is None:
        file_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'win_img_' + script_name + '_' + file_time + '.png')
    # Write straight to disk: going through plt.imshow would create a pyplot
    # figure that keeps this full-size copy referenced until it is closed.
    plt.imsave(out_path, img_n)
    return out_path
def img_extend(img, img_rows, img_cols):
    """Pad an image by replicating its edge pixels.

    ``half = floor(img_cols / 2)`` pixels are added on each side of both
    spatial axes. ``img_cols`` sets the margin for rows and columns alike;
    ``img_rows`` is accepted for a uniform signature but not used. Border
    strips repeat the nearest edge row/column, and corner blocks take the
    corner pixel's value.

    Returns
    -------
    ndarray
        New array of shape ``(H + 2*half, W + 2*half, ...)`` with the same
        dtype as ``img``. Trailing (channel) axes are not padded.
    """
    half = int(np.floor(img_cols / 2))
    # Pad only the two spatial axes. np.pad keeps img's dtype (no forced
    # uint8 border) and allocates the result once, which matters for very
    # large images.
    pad_width = [(half, half), (half, half)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode='edge')
def get_data_from_img(img, img_bin, img_rows, img_cols, Samp_Max, case, Pos_Max=None):
    """Cut one labelled patch per mask pixel from an edge-padded image.

    ``img`` is expected to be padded by ``half = floor(img_cols / 2)`` pixels
    on every side (as ``img_extend`` does), while ``img_bin`` is the unpadded
    mask. Mask pixel ``(r, c)`` therefore maps to the patch
    ``img[r:r + 2*half, c:c + 2*half]``. Mask value 255 marks a positive
    sample; any other value marks a negative.

    Parameters
    ----------
    img : ndarray, shape (rows, cols[, channels])
    img_bin : array_like, 2-D
    img_rows : int
        Accepted for a uniform signature; not used.
    img_cols : int
        Determines ``half`` and the patch side ``2*half``.
    Samp_Max : int
        If there are more than ``Samp_Max`` negatives, keep a random subset
        of exactly ``Samp_Max`` (``random.sample`` order); otherwise keep all
        of them in raster order.
    case : str
        Only ``'train'`` is supported.
    Pos_Max : int or None, optional
        Same cap for positives. ``None`` (default) keeps every positive.

    Returns
    -------
    data : list of ndarray
        Patch views into ``img``: kept positives first, then kept negatives.
    label : list
        Mask value at each patch's centre pixel, in the same order.

    Raises
    ------
    ValueError
        If ``case`` is not ``'train'``.
    IndexError
        If ``img_bin`` is smaller than the unpadded image area.
    """
    if case != 'train':
        raise ValueError("unsupported case %r (only 'train' is implemented)" % (case,))
    half = int(np.floor(img_cols / 2))
    side = 2 * half
    # Unpadded area whose patches lie fully inside img (clamped so a tiny
    # img never produces a negative slice stop).
    n_rows = max(img.shape[0] - side, 0)
    n_cols = max(img.shape[1] - side, 0)
    mask_src = np.asarray(img_bin)
    if mask_src.shape[0] < n_rows or mask_src.shape[1] < n_cols:
        raise IndexError('img_bin of shape %r is smaller than the unpadded image (%d, %d)'
                         % (mask_src.shape, n_rows, n_cols))
    centre = mask_src[:n_rows, :n_cols]

    # Work with flat pixel indices (a few bytes per pixel) and subsample them
    # BEFORE creating any patch object. Building one Python view per pixel
    # first is what exhausts memory on very large images.
    is_pos = centre == 255
    pos_flat = np.flatnonzero(is_pos)   # raster order
    neg_flat = np.flatnonzero(~is_pos)  # raster order
    if Pos_Max is not None and pos_flat.size > Pos_Max:
        picked = random.sample(range(pos_flat.size), Pos_Max)
        pos_flat = pos_flat[np.asarray(picked, dtype=np.intp)]
    if neg_flat.size > Samp_Max:
        # Same random.sample call as a per-pixel list would use, so the
        # selected negatives are identical for a given random seed.
        picked = random.sample(range(neg_flat.size), Samp_Max)
        neg_flat = neg_flat[np.asarray(picked, dtype=np.intp)]

    rows, cols = np.divmod(np.concatenate((pos_flat, neg_flat)), max(n_cols, 1))
    data = [img[r:r + side, c:c + side, ...] for r, c in zip(rows, cols)]
    label = list(centre[rows, cols])
    return data, label
def reshape_data(data):
    """Turn a stack of HWC patches into a normalised NCHW float32 array.

    Parameters
    ----------
    data : array_like, shape (N, rows, cols, channels)

    Returns
    -------
    ndarray of float32, shape (N, channels, rows, cols)
        Each value ``x`` becomes ``(x - 127) / 128``, so 8-bit inputs
        0..255 land in ``[-127/128, 1]`` (not ``[0, 1]``).
    """
    # Reorder axes to [number_of_samples, channels, rows, cols]; this is a
    # view when data is already an ndarray.
    data = np.transpose(data, (0, 3, 1, 2))
    # astype always returns a fresh float32 copy, so the in-place ops below
    # never touch the caller's array and avoid two full-size temporaries.
    data = data.astype('float32')
    data -= 127
    data /= 128
    return data
########### file locations and global variables ###########
# img_cols sets the patch geometry (patch side = 2 * floor(img_cols / 2));
# img_rows is passed along to the helpers as well.
img_rows = 18
img_cols = 18
# NOTE(review): n_procs is not used in the visible code (joblib is imported
# but never called) -- confirm before removing.
n_procs = 10
# File name of this script; SCRIPT_NAME is an upper-case alias for code that
# refers to it by that spelling.
Script_name = os.path.basename(__file__)
SCRIPT_NAME = Script_name
data_file = os.path.join('..', 'data')
# Directory for result images.
result_file = os.path.join('..', 'result')
now_global = datetime.datetime.now()
file_time_global = now_global.strftime("%Y-%m-%d-%H-%M")
############################## prepare data ###########################
# Read the training image and its label mask from data_file.
ori_img = plt.imread(os.path.join(data_file, 'ori_train.tif'))
# NOTE(review): the *255 assumes the mask file stores 0/1 values; if it is
# already a 0/255 uint8 image this multiplication wraps around -- confirm
# against the data.
ori_bin_img = plt.imread(os.path.join(data_file, 'bin_train.tif')) * 255

# Window list for the image (see get_train_win) and a preview with
# 3-pixel window borders.
win, nb_win = get_train_win(ori_img)
draw_win_in_ori(ori_img, win, 3)
# Crop every window out of the image first, then out of the mask.
imgs, imgs_bin = [get_img_from_win(src, win, img_rows, img_cols)
                  for src in (ori_img, ori_bin_img)]
############################ get train data and label ############################
# Gather patches and labels from every window (at most 50000 negatives each).
X_T, y_T = [], []
for i in range(len(imgs)):
    Train_data, Train_label = get_data_from_img(
        imgs[i], imgs_bin[i], img_rows, img_cols, 50000, 'train')
    X_T += Train_data
    y_T += Train_label

# Shuffle samples and labels with one shared random permutation.
shuffle = random.sample(range(len(X_T)), len(X_T))
train_data = np.array([X_T[k] for k in shuffle])
labels = [y_T[k] for k in shuffle]
# NCHW float32 inputs; mask value 255 becomes label 1.0.
train_data = reshape_data(train_data)
labels = np.array(labels) / 255
# Hold out 20% for validation with a fixed seed.
X_train, X_valid, y_train, y_valid = train_test_split(
    train_data, labels, test_size=0.2, random_state=42)