Why are there stripes in the predicted images from Keras?

Asked: 2017-08-29 07:05:39

Tags: machine-learning neural-network deep-learning keras image-segmentation

In my test I am doing two-class segmentation in Keras to mask clouds in satellite images.

Two samples are used for training, validation, and testing, deliberately overfitting. A striped prediction image like this is very strange (the predicted label is on the right; the image and its ground-truth label are on the left): [screenshots: the input image with its label, and the striped prediction]

My code is below. What is wrong with my code, or is this a problem with Keras?

#coding=utf-8
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D,UpSampling2D,BatchNormalization,Reshape,Permute,Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.optimizers import SGD
from keras.optimizers import RMSprop,Adadelta,Adagrad,Adam
from keras.wrappers.scikit_learn import KerasClassifier
from keras.callbacks import ModelCheckpoint,LearningRateScheduler
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import matplotlib.pyplot as plt
from libtiff import TIFF
from skimage import exposure
from keras import backend as k
from keras.callbacks import TensorBoard
import pandas as pd
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
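# allow the TensorFlow session to grow GPU memory on demand instead of pre-allocating the whole card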
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
set_session(tf.Session(config=config))

seed = 7  
np.random.seed(seed)  

#data_shape = 360*480  
img_w = 256  
img_h = 256  

n_label = 2  

classes = [0.  ,  1.]  

labelencoder = LabelEncoder()  
labelencoder.fit(classes)  

def load_img(path, grayscale=False, target_size=None):  
    img = Image.open(path)  
    if grayscale:  
        if img.mode != 'L':  
            img = img.convert('L')  
    if target_size:  
        wh_tuple = (target_size[1], target_size[0])  
        if img.size != wh_tuple:  
            img = img.resize(wh_tuple)  
    return img  

train_url = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r').readlines()  
#trainval_url = open(r'/media/wmy/document/BigData/VOCdevkit/VOC2012/ImageSets/Segmentation/trainval.txt','r').readlines()  
val_url = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r').readlines()  
train_numb = len(train_url)  
valid_numb = len(val_url)  
print "the number of train data is",train_numb  
print "the number of val data is",valid_numb  

directory = '/media/private/yanyuan/yan/image_block/image256/'

def generateData(batch_size):
    with open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r') as f:
        train_url = [line.strip('\n') for line in f.readlines()]

    while True:
        train_data = []
        train_label = []
        batch = 0
        for url in train_url:
            batch += 1
#            img = load_img(filepath + 'JPEGImages/' + url.strip('\n') + '.jpg', target_size=(img_w, img_h))
#            img = img_to_array(img)
            tif = TIFF.open(directory + 'images/' + url + '.tiff')
            img = tif.read_image()
            mean_vec = np.array([456,495,440,446],dtype=np.float32)
            mean_vec = mean_vec.reshape(1,1,4)
            TIFF.close(tif)
            img = np.array(img, dtype=np.float32)
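            # center each of the 4 bands with its mean and rescale the 16-bit values by 1/65535 (= 1.525902189e-5) into roughly [0, 1]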
            img = img - mean_vec
            img *= 1.525902189e-5
            # print img.shape
            train_data.append(img)
#            label = load_img(filepath + 'SegmentationClass/' + url.strip('\n') + '.png', target_size=(img_w, img_h))
            label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))            
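            # flatten the label mask so there is one class value per pixel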
            label = img_to_array(label).reshape((img_w * img_h,))
            # print label.shape
            train_label.append(label)
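            # once a full batch is collected, one-hot encode the labels to (batch_size, img_w*img_h, n_label) and yield it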
            if batch % batch_size==0:
                train_data = np.array(train_data)
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size,img_w * img_h,n_label))
                yield (train_data,train_label)
                train_data = []
                train_label = []
                batch = 0

def generateValidData(batch_size):
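    # same preprocessing as generateData, but reads the validation list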

    with open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r') as f:
        val_url = [line.strip('\n') for line in f.readlines()]
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for url in val_url:
            batch += 1
            #img = load_img(filepath + 'JPEGImages/' + url.strip('\n') + '.jpg', target_size=(img_w, img_h))
            #img = img_to_array(img)
            tif = TIFF.open(directory + 'images/' + url + '.tiff')
            img = tif.read_image()
            mean_vec = np.array([456,495,440,446],dtype=np.float32)
            mean_vec = mean_vec.reshape(1,1,4)
            TIFF.close(tif)
            img = np.array(img, dtype=np.float32) 
            img = img - mean_vec
            img *= 1.525902189e-5
            #print(img.shape)

            # print img.shape
            valid_data.append(img)
            label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))
            label = img_to_array(label).reshape((img_w * img_h,))
            # print label.shape
            valid_label.append(label)
            if batch % batch_size==0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size,img_w * img_h,n_label))
                yield (valid_data,valid_label)
                valid_data = []
                valid_label = []
                batch = 0  
def SegNet():
    model = Sequential()
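    # encoder: four Conv-BN-ReLU-MaxPool blocks shrink 256x256 to 16x16; decoder: four UpSampling-Conv-BN blocks restore 256x256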
    #encoder
    model.add(Conv2D(64,(7,7),strides=(1,1),input_shape=(img_w,img_h,4),padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    #(128,128)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    #(64,64)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu')) 
    model.add(MaxPooling2D(pool_size=(2, 2)))
    #(32,32)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())    
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    #(16,16)

    #decoder
    model.add(UpSampling2D(size=(2,2)))
    #(16,16)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())

    model.add(UpSampling2D(size=(2, 2)))
    #(32,32)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())

    model.add(UpSampling2D(size=(2, 2)))
    #(64,64)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())

    model.add(UpSampling2D(size=(2, 2)))
    #(128,128)
    model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
    model.add(BatchNormalization())

    #(256,256)

    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
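    # reshape the per-pixel scores to (n_label, img_w*img_h), permute to (img_w*img_h, n_label), then softmax over the classes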
    model.add(Reshape((n_label,img_w*img_h)))

    model.add(Permute((2,1)))
    model.add(Activation('softmax'))
    sgd=SGD(lr=0.1,momentum=0.95,decay=0.0005,nesterov=False)
    adam = Adam(lr=0.001,beta_1=0.9,beta_2=0.999,decay=0.0005)
    #model.compile(loss='categorical_crossentropy',optimizer=sgd,metrics=['accuracy'])
    model.compile(loss=keras.losses.categorical_crossentropy,optimizer=sgd,metrics=['accuracy'])
    model.summary()
    return model 


def train():
    k.set_learning_phase(1)
    model = SegNet()  
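    # keep only the weights with the best validation accuracy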
    modelcheck = ModelCheckpoint('Segnet_params.h5',monitor='val_acc',save_best_only=True,mode='auto')  
    callable = [modelcheck]  
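    # steps_per_epoch=1 with a batch size of 2: the same two samples are fed every epoch (deliberate overfitting)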

    history = model.fit_generator(verbose=2,
                                  generator=generateData(2),
                                  validation_data=generateValidData(2),
                                  steps_per_epoch=1, 
                                  epochs=600, 
                                  callbacks=callable, 
                                  max_queue_size=1, 
                                  class_weight = None,
                                  validation_steps=1)

    drawLoss(history)

def drawLoss(history):
    plt.figure()

    plt.plot(history.history['acc'],'g')
    plt.plot(history.history['val_acc'],'r')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
    # summarize history for loss
    plt.figure()

    plt.plot(history.history['loss'],'g')
    plt.plot(history.history['val_loss'],'r')
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()
def predict():
    k.set_learning_phase(0)
    model = SegNet()
    model.load_weights('Segnet_params.h5')
    file = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r')
    train_url = [line.strip('\n') for line in file.readlines()]
    for url in train_url:
        tif = TIFF.open(directory + 'images/' + url + '.tiff')
        img = tif.read_image()
        TIFF.close(tif)
        mean_vec = np.array([456,495,440,446],dtype=np.float32)
        mean_vec = mean_vec.reshape(1,1,4)
        img = np.array(img, dtype=np.float32) 
        img = img - mean_vec
        img *= 1.525902189e-5
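        # build a display copy: 2-98 percentile contrast stretch per band, then take bands (2,1,0) as an RGB-like uint8 image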
        im = np.empty_like(img)

        for j in range(4):

            l_val,r_val = np.percentile(img[:,:,j],(2,98),interpolation='linear')
            im[:,:,j] = exposure.rescale_intensity(img[:,:,j], in_range=(l_val,r_val),out_range='uint8')

        im = im[:,:,(2,1,0)]
        im = im.astype(np.uint8)
        img = img.reshape(1,img_h,img_w,-1)
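        # predict the per-pixel classes for this single image and map them back to the original label values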
        pred = model.predict_classes(img,verbose=2)
        pred = labelencoder.inverse_transform(pred[0])
        print np.unique(pred)
        pred = pred.reshape((img_h,img_w)).astype(np.uint8)
        pred_img = Image.fromarray(pred)
#        pred_img.save('1.png',format='png')
        label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))
        #print pred
        plt.figure()
        plt.subplot(2,2,1)
        plt.imshow(im)
        plt.subplot(2,2,2)
        plt.imshow(pred)
        plt.subplot(2,2,3)
        plt.imshow(label)
        plt.show()

if __name__=='__main__':
    train()
    predict()

Thanks in advance.

0 Answers:

No answers yet.