在我的测试中,我用 Keras 做两类分割,从卫星图像中掩膜出云。
训练、验证和测试都只用了同样的两个样本,故意让模型过拟合。但预测出的图像带有非常奇怪的条带,如下所示(右边是预测标签,左边是原图和真实标签):
我的代码如下。请问是我的代码哪里出了问题,还是 Keras 的问题?
#coding=utf-8
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D,UpSampling2D,BatchNormalization,Reshape,Permute,Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.optimizers import SGD
from keras.optimizers import RMSprop,Adadelta,Adagrad,Adam
from keras.wrappers.scikit_learn import KerasClassifier
from keras.callbacks import ModelCheckpoint,LearningRateScheduler
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import matplotlib.pyplot as plt
from libtiff import TIFF
from skimage import exposure
from keras import backend as k
from keras.callbacks import TensorBoard
import pandas as pd
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
set_session(tf.Session(config=config))
seed = 7
np.random.seed(seed)
#data_shape = 360*480
img_w = 256
img_h = 256
n_label = 2
classes = [0. , 1.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)
def load_img(path, grayscale=False, target_size=None):
img = Image.open(path)
if grayscale:
if img.mode != 'L':
img = img.convert('L')
if target_size:
wh_tuple = (target_size[1], target_size[0])
if img.size != wh_tuple:
img = img.resize(wh_tuple)
return img
train_url = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r').readlines()
#trainval_url = open(r'/media/wmy/document/BigData/VOCdevkit/VOC2012/ImageSets/Segmentation/trainval.txt','r').readlines()
val_url = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r').readlines()
train_numb = len(train_url)
valid_numb = len(val_url)
print "the number of train data is",train_numb
print "the number of val data is",valid_numb
directory = '/media/private/yanyuan/yan/image_block/image256/'
def generateData(batch_size):
with open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r') as f:
train_url = [line.strip('\n') for line in f.readlines()]
while True:
train_data = []
train_label = []
batch = 0
for url in train_url:
batch += 1
# img = load_img(filepath + 'JPEGImages/' + url.strip('\n') + '.jpg', target_size=(img_w, img_h))
# img = img_to_array(img)
tif = TIFF.open(directory + 'images/' + url + '.tiff')
img = tif.read_image()
mean_vec = np.array([456,495,440,446],dtype=np.float32)
mean_vec = mean_vec.reshape(1,1,4)
TIFF.close(tif)
img = np.array(img, dtype=np.float32)
img = img - mean_vec
img *= 1.525902189e-5
# print img.shape
train_data.append(img)
# label = load_img(filepath + 'SegmentationClass/' + url.strip('\n') + '.png', target_size=(img_w, img_h))
label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))
label = img_to_array(label).reshape((img_w * img_h,))
# print label.shape
train_label.append(label)
if batch % batch_size==0:
train_data = np.array(train_data)
train_label = np.array(train_label).flatten()
train_label = labelencoder.transform(train_label)
train_label = to_categorical(train_label, num_classes=n_label)
train_label = train_label.reshape((batch_size,img_w * img_h,n_label))
yield (train_data,train_label)
train_data = []
train_label = []
batch = 0
def generateValidData(batch_size):
with open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r') as f:
val_url = [line.strip('\n') for line in f.readlines()]
while True:
valid_data = []
valid_label = []
batch = 0
for url in val_url:
batch += 1
#img = load_img(filepath + 'JPEGImages/' + url.strip('\n') + '.jpg', target_size=(img_w, img_h))
#img = img_to_array(img)
tif = TIFF.open(directory + 'images/' + url + '.tiff')
img = tif.read_image()
mean_vec = np.array([456,495,440,446],dtype=np.float32)
mean_vec = mean_vec.reshape(1,1,4)
TIFF.close(tif)
img = np.array(img, dtype=np.float32)
img = img - mean_vec
img *= 1.525902189e-5
#print(img.shape)
# print img.shape
valid_data.append(img)
label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))
label = img_to_array(label).reshape((img_w * img_h,))
# print label.shape
valid_label.append(label)
if batch % batch_size==0:
valid_data = np.array(valid_data)
valid_label = np.array(valid_label).flatten()
valid_label = labelencoder.transform(valid_label)
valid_label = to_categorical(valid_label, num_classes=n_label)
valid_label = valid_label.reshape((batch_size,img_w * img_h,n_label))
yield (valid_data,valid_label)
valid_data = []
valid_label = []
batch = 0
def SegNet():
model = Sequential()
#encoder
model.add(Conv2D(64,(7,7),strides=(1,1),input_shape=(img_w,img_h,4),padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
#(128,128)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
#(64,64)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
#(32,32)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
#(16,16)
#decoder
model.add(UpSampling2D(size=(2,2)))
#(16,16)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(UpSampling2D(size=(2, 2)))
#(32,32)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(UpSampling2D(size=(2, 2)))
#(64,64)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(UpSampling2D(size=(2, 2)))
#(128,128)
model.add(Conv2D(64, (7, 7), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
#(256,256)
model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
model.add(Reshape((n_label,img_w*img_h)))
model.add(Permute((2,1)))
model.add(Activation('softmax'))
sgd=SGD(lr=0.1,momentum=0.95,decay=0.0005,nesterov=False)
adam = Adam(lr=0.001,beta_1=0.9,beta_2=0.999,decay=0.0005)
#model.compile(loss='categorical_crossentropy',optimizer=sgd,metrics=['accuracy'])
model.compile(loss=keras.losses.categorical_crossentropy,optimizer=sgd,metrics=['accuracy'])
model.summary()
return model
def train():
k.set_learning_phase(1)
model = SegNet()
modelcheck = ModelCheckpoint('Segnet_params.h5',monitor='val_acc',save_best_only=True,mode='auto')
callable = [modelcheck]
history = model.fit_generator(verbose=2,
generator=generateData(2),
validation_data=generateValidData(2),
steps_per_epoch=1,
epochs=600,
callbacks=callable,
max_queue_size=1,
class_weight = None,
validation_steps=1)
drawLoss(history)
def drawLoss(history):
plt.figure()
plt.plot(history.history['acc'],'g')
plt.plot(history.history['val_acc'],'r')
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
# summarize history for loss
plt.figure()
plt.plot(history.history['loss'],'g')
plt.plot(history.history['val_loss'],'r')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
def predict():
k.set_learning_phase(0)
model = SegNet()
model.load_weights('Segnet_params.h5')
file = open(r'/media/private/yanyuan/yan/image4band16bit/train104/test4.txt','r')
train_url = [line.strip('\n') for line in file.readlines()]
for url in train_url:
tif = TIFF.open(directory + 'images/' + url + '.tiff')
img = tif.read_image()
TIFF.close(tif)
mean_vec = np.array([456,495,440,446],dtype=np.float32)
mean_vec = mean_vec.reshape(1,1,4)
img = np.array(img, dtype=np.float32)
img = img - mean_vec
img *= 1.525902189e-5
im = np.empty_like(img)
for j in range(4):
l_val,r_val = np.percentile(img[:,:,j],(2,98),interpolation='linear')
im[:,:,j] = exposure.rescale_intensity(img[:,:,j], in_range=(l_val,r_val),out_range='uint8')
im = im[:,:,(2,1,0)]
im = im.astype(np.uint8)
img = img.reshape(1,img_h,img_w,-1)
pred = model.predict_classes(img,verbose=2)
pred = labelencoder.inverse_transform(pred[0])
print np.unique(pred)
pred = pred.reshape((img_h,img_w)).astype(np.uint8)
pred_img = Image.fromarray(pred)
# pred_img.save('1.png',format='png')
label = load_img(directory + 'labels/' + url + '.png', target_size=(img_w, img_h))
#print pred
plt.figure()
plt.subplot(2,2,1)
plt.imshow(im)
plt.subplot(2,2,2)
plt.imshow(pred)
plt.subplot(2,2,3)
plt.imshow(label)
plt.show()
if __name__=='__main__':
train()
predict()
提前致谢。