彩色CNN,约128000张图像,内存错误

时间:2019-05-09 13:49:05

标签: python

我目前正在尝试实现CNN来对灰度图像进行着色。我一直在关注R. Zhang的报告,至少在某种程度上,一切似乎都在起作用。问题是我的代码基于使用5000张64 x 64像素的图像,也就是说,我当前正在读取numpy的维度数组(5000 x 64 x 64)。我想做的是使用大约128 000张图像代替(当然,这将需要更长的时间,但是将来会是个问题)。我还尝试对这些图像的颜色标签进行软编码,问题是当尝试使用128 000张图像时,我最终遇到了内存错误。因此,我的问题是,我可以对我的代码进行一些更改(使用tf占位符等),以避免出现此内存错误吗?

from __future__ import print_function
import pickle
from sklearn.neighbors import NearestNeighbors
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Conv2D, UpSampling2D, BatchNormalization, Dropout
from skimage import color
import cv2 as cv
from keras.regularizers import l2
import tensorflow as tf
from keras import backend as K

class network:
    def __init__(self, points_n, epochs, batch_size):
        #self.points_n = points_n
        self.interpolation_factor = 4
        self.x = self.load_databatch()
        #self.epochs = epochs
        #self.batch_size = batch_size
        self.Q = np.load('pts_in_hull.npy')
        #self.v = tf.convert_to_tensor(np.load('prior_probs.npy').astype(np.float32)) / 101.378494
        self.sigma = 5
        #self.Z, self.L = self.soft_encode()


    def unpickle(self, file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo)
        return dict

    def load_databatch(self):
        d = self.unpickle('train_data_batch_1')
        x = d['data']
        area_img = int((x.shape[1]) / 3)
        x = np.dstack((x[:, :area_img], x[:, area_img:2 * area_img], x[:, 2 * area_img:]))
        self.h = self.w = int(np.sqrt(area_img))
        x = (x.reshape(x.shape[0], self.h, self.w, 3))
        #print(x.shape)
        return x

    def gaussian_kernel(self, x):
        return np.exp(-0.5 * ((x**2)/(self.sigma**2)))

    def soft_encode(self):
        nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(self.Q)
        self.new_h, self.new_w = int(self.h / self.interpolation_factor), int(self.w / self.interpolation_factor)
        #Z = np.zeros((self.x.shape[0], self.new_h, self.new_w, self.Q.shape[0]))
        Z = []
        L = np.zeros((self.x.shape[0], self.h, self.w, 1))
        row_idx = np.arange(self.new_h * self.new_w).reshape(self.new_h * self.new_w, 1)
        for i in range(self.x.shape[0]):
            cie_img = color.rgb2lab(self.x[i])
            l = (cie_img[:, :, 0] -50) / 50
            l = l.reshape((cie_img.shape[0], cie_img.shape[1], 1))
            L[i] = l
            resize_img = cv.resize(cie_img, (self.new_h, self.new_w),
                            cv.INTER_AREA)
            ab_channel_img = resize_img[:, :,1:3].reshape((self.new_h * self.new_w, resize_img.shape[2]-1))
            distances, indices = nbrs.kneighbors(ab_channel_img)
            knn_Y = self.gaussian_kernel(distances)
            knn_Y = knn_Y / knn_Y.sum(axis = 1, keepdims = True)
            z = np.zeros((self.new_h * self.new_w, self.Q.shape[0]))
            z[row_idx, indices] = knn_Y
            z = z.reshape((self.new_h, self.new_w, self.Q.shape[0]))
            Z.append(z)
        #return np.array(Z), L
        print(np.array(Z).shape)

network1 = network(500, 1, 50)
#network1.load_databatch()
#network1.soft_encode(

0 个答案:

没有答案