使用生成器后,训练精度保持恒定

时间:2020-08-17 14:37:24

标签: tensorflow keras generator image-segmentation

我正在使用UNet进行图像分割。我从单波段图像开始,并使用以下代码训练了模型:

from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models 

# Output directory for the trained model; matching image/mask .tif pairs are
# discovered by glob.  sorted() keeps the two lists index-aligned by filename.
output_path = '/output/path/'
input_img = sorted(glob('/input_img/path/*.tif'))
input_mask = sorted(glob('/input_mask/path/*.tif'))

# 80/20 train/validation split of the file paths; fixed random_state makes the
# split reproducible across runs.
img, img_val, mask, mask_val = train_test_split(input_img, input_mask, test_size=0.2, random_state=42)

#create lists of training and validation arrays
def _load_band(paths):
    """Read each single-band .tif in *paths* into a (512, 512, 1) float array.

    NaN pixels are replaced with 0 before the band is copied into the
    zero-initialized channel axis, matching the original per-list loops.
    """
    arrays = []
    for path in paths:
        band = np.array(Image.open(path))
        band[np.isnan(band)] = 0
        arr = np.zeros((512, 512, 1))
        arr[:, :, 0] = band
        arrays.append(arr)
    return arrays

# The four byte-identical read loops collapse into one helper call each.
train_image = _load_band(img)
train_mask = _load_band(mask)
test_image = _load_band(img_val)
test_mask = _load_band(mask_val)

# tensorflow format
# Wrap the in-memory arrays as tf.data pipelines for model.fit.
train= tf.data.Dataset.from_tensor_slices((train_image, train_mask))
test = tf.data.Dataset.from_tensor_slices((test_image, test_mask))

train_length = len(train_image)
img_shape = (512,512,1)  # single-band 512x512 input
batch_size = 10
epochs = 200

# Cache, shuffle over the whole training set, batch, and repeat indefinitely
# (steps_per_epoch bounds each epoch); prefetch overlaps input prep with training.
train_dataset = train.cache().shuffle(train_length).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(batch_size).repeat()

# Build model
def conv_block(input_tensor, num_filters):
    """Two consecutive Conv(3x3, same)-BatchNorm-ReLU stages; spatial size unchanged."""
    x = input_tensor
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def encoder_block(input_tensor, num_filters):
    """Run conv_block then 2x2 max-pool; return (pooled, pre-pool skip tensor)."""
    features = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D((2, 2), strides=(2, 2))(features)
    return pooled, features

def decoder_block(input_tensor, concat_tensor, num_filters):
    """Upsample 2x via transposed conv, concat the encoder skip tensor, then
    BN-ReLU followed by two Conv(3x3)-BN-ReLU stages."""
    x = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
    x = layers.concatenate([concat_tensor, x], axis=-1)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

# U-Net assembly: 5 encoder stages, a 1024-filter bottleneck, and 5 mirrored
# decoder stages with skip connections, ending in a per-pixel 2-class head.
inputs = layers.Input(shape=img_shape)
encoder0_pool, encoder0 = encoder_block(inputs, 32)
encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)

center = conv_block(encoder4_pool, 1024)

decoder4 = decoder_block(center, encoder4, 512)
decoder3 = decoder_block(decoder4, encoder3, 256)
decoder2 = decoder_block(decoder3, encoder2, 128)
decoder1 = decoder_block(decoder2, encoder1, 64)
# BUG FIX: sparse_categorical_crossentropy expects a probability distribution
# over the 2 classes at each pixel, so the head must be softmax.  Independent
# per-channel sigmoids do not sum to 1, breaking the loss/accuracy pairing.
decoder0 = decoder_block(decoder1, encoder0, 32)
outputs = layers.Conv2D(2, (1, 1), activation='softmax')(decoder0)

model = models.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Callbacks: keep the checkpoint with the lowest validation loss, and decay
# the learning rate exponentially once epoch 151 is reached.
save_model_path = '/tmp/dir/tmp.hdf5'
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_loss', mode='min', save_best_only=True)

def scheduler(epoch, lr):
  """Hold lr constant for epochs 0-150, then multiply by exp(-0.1) each epoch."""
  return lr if epoch < 151 else lr * tf.math.exp(-0.1)

lrs = tf.keras.callbacks.LearningRateScheduler(scheduler)

# fit
# steps_per_epoch / validation_steps are mandatory here because both datasets
# repeat() indefinitely; ceil() includes the final partial batch.
history = model.fit(train_dataset, 
                   steps_per_epoch=int(np.ceil(train_length / float(batch_size))),
                   epochs=epochs,
                   validation_data=test_dataset,
                   validation_steps=int(np.ceil(len(test_image) / float(batch_size))),
                   callbacks=[cp,lrs])
# Final weights; the best val_loss checkpoint is saved separately by `cp`.
model.save(output_path+'model.h5')

脚本运行良好,并产生如下训练曲线: Training curve without generator

接下来,我将处理9波段光谱图像。由于一次读取所有数据会使colab会话崩溃,因此我构建了一个生成器来进行批处理。脚本如下:

from glob import glob
from sklearn.model_selection import train_test_split
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import rasterio
from tensorflow import keras
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from keras.callbacks import ModelCheckpoint
from keras.callbacks import LearningRateScheduler

output_path = '/output/path/'        # directory where the trained model is written
input_img_path = '/input_img/path/'  # directory of input image rasters
input_mask_path = '/input_mask/path/'  # directory of mask rasters (same file names as images)

# img reading function
def img_load(path):
  """Read band 1 of the raster at *path* into a (512, 512, 1) array, NaNs zeroed."""
  src = rasterio.open(path)
  out = np.zeros((512, 512, 1))
  out[:, :, 0] = src.read(1)
  out[np.isnan(out)] = 0
  return out

def mask_load(path):
  """Read a single-band mask raster as a (512, 512, 1) array with NaNs zeroed.

  The loading logic was a byte-for-byte duplicate of img_load, so this now
  delegates to it; behavior is unchanged.
  """
  return img_load(path)

#generator
class DataGen(keras.utils.Sequence):
  """Keras Sequence yielding (image, mask) batches loaded lazily from disk.

  Image and mask files share the same file name and live under img_path and
  mask_path respectively.
  """

  def __init__(self, files, img_path, mask_path, batch_size=8):
    self.files = files          # list of shared image/mask file names
    self.img_path = img_path    # directory of image rasters
    self.mask_path = mask_path  # directory of mask rasters
    self.batch_size = batch_size
    self.on_epoch_end()

  def __load__(self, files_name):
    # Load one image and its mask; both resolve from the same file name.
    img_path = os.path.join(self.img_path, files_name)
    mask_path = os.path.join(self.mask_path, files_name)
    return img_load(img_path), mask_load(mask_path)

  def __getitem__(self, index):
    # BUG FIX: the original mutated self.batch_size down to the size of the
    # last partial batch, permanently shrinking the batch size and scrambling
    # batch boundaries (and __len__) on every subsequent epoch.  List slicing
    # already clips the final batch, so no mutation is needed.
    files_batch = self.files[index * self.batch_size:(index + 1) * self.batch_size]

    image = []
    mask = []
    for files_name in files_batch:
      _img, _mask = self.__load__(files_name)
      image.append(_img)
      mask.append(_mask)

    return np.array(image), np.array(mask)

  def on_epoch_end(self):
    # No per-epoch shuffling; hook retained to satisfy the Sequence interface.
    pass

  def __len__(self):
    # Number of batches per epoch; the last batch may be partial.
    return int(np.ceil(len(self.files) / float(self.batch_size)))

#read through gen
batch_size = 8
epochs = 300
# File names (not full paths) found in the image directory; masks reuse them.
train_data = next(os.walk(input_img_path))[2]
train_data, valid_data = train_test_split(train_data, test_size=0.2, random_state=42)

#test generator
# BUG FIX: the original referenced undefined names train_img_path /
# train_mask_path (NameError); the directories defined above are
# input_img_path / input_mask_path.
gen = DataGen(train_data, input_img_path, input_mask_path, batch_size=batch_size)
x, y = gen[0]  # idiomatic Sequence indexing instead of gen.__getitem__(0)
print(x.shape, y.shape)
#output = (8, 512, 512, 1) (8, 512, 512, 1)
#I also plot the image and mask to make sure data is ok

# train_gen and valid_gen
train_gen = DataGen(train_data, input_img_path, input_mask_path, batch_size=batch_size)
valid_gen = DataGen(valid_data, input_img_path, input_mask_path, batch_size=batch_size)

# build model (exactly the same as the above one)
def conv_block(input_tensor, num_filters):
    """Two consecutive Conv(3x3, same)-BatchNorm-ReLU stages; spatial size unchanged."""
    x = input_tensor
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def encoder_block(input_tensor, num_filters):
    """Run conv_block then 2x2 max-pool; return (pooled, pre-pool skip tensor)."""
    features = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D((2, 2), strides=(2, 2))(features)
    return pooled, features

def decoder_block(input_tensor, concat_tensor, num_filters):
    """Upsample 2x via transposed conv, concat the encoder skip tensor, then
    BN-ReLU followed by two Conv(3x3)-BN-ReLU stages."""
    x = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
    x = layers.concatenate([concat_tensor, x], axis=-1)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    for _ in range(2):
        x = layers.Conv2D(num_filters, (3, 3), padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

# U-Net assembly (identical topology to the first script): 5 encoder stages,
# 1024-filter bottleneck, 5 mirrored decoder stages with skip connections.
inputs = layers.Input(shape=(512,512,1))
encoder0_pool, encoder0 = encoder_block(inputs, 32)
encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)

center = conv_block(encoder4_pool, 1024)

decoder4 = decoder_block(center, encoder4, 512)
decoder3 = decoder_block(decoder4, encoder3, 256)
decoder2 = decoder_block(decoder3, encoder2, 128)
decoder1 = decoder_block(decoder2, encoder1, 64)
decoder0 = decoder_block(decoder1, encoder0, 32)
# BUG FIX: sparse_categorical_crossentropy expects a probability distribution
# over the 2 classes at each pixel, so the head must be softmax.  Independent
# per-channel sigmoids do not sum to 1, breaking the loss/accuracy pairing.
outputs = layers.Conv2D(2, (1, 1), activation='softmax')(decoder0)

model = models.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

#set callbacks
# Keep the checkpoint with the lowest validation loss; decay the learning
# rate exponentially once epoch 226 is reached.
save_model_path = '/tmp/dir/tmp.hdf5'
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_loss', mode='min', save_best_only=True)

def scheduler(epoch, lr):
  """Hold lr constant for epochs 0-225, then multiply by exp(-0.1) each epoch."""
  return lr if epoch < 226 else lr * tf.math.exp(-0.1)

lrs = tf.keras.callbacks.LearningRateScheduler(scheduler)

#fit
# The Sequence generators know their own length, but explicit
# steps_per_epoch / validation_steps are kept to mirror the first script.
history = model.fit(train_gen, 
                   steps_per_epoch=int(np.ceil(len(train_data) / float(batch_size))),
                   epochs=epochs,
                   validation_data=valid_gen,
                   validation_steps=int(np.ceil(len(valid_data) / float(batch_size))),
                   callbacks=[cp,lrs])
# Final weights; the best val_loss checkpoint is saved separately by `cp`.
model.save(output_path+'model.h5')

我使用相同的单波段图像运行脚本来测试生成器,但是训练精度保持不变,而训练损失和验证损失却在下降(见下图)。 Training curve with generator

生成器有什么问题吗?

=====更新=====

不知道原因,但似乎通过 tf.data.Dataset 馈送数据,与通过 keras.Sequence 生成器按批返回图像来馈送数据,二者的行为并不相同。

我对类进行了如下修改:

class DataGen:
  """Callable iterator over (image, mask) pairs for tf.data.Dataset.from_generator.

  Calling the instance resets the cursor to 0, so from_generator can
  re-iterate the same object on every epoch.
  """

  def __init__(self, files, data_path):
    self.i = 0                # cursor into self.files
    self.files = files        # per-sample file names
    self.data_path = data_path

  def __load__(self, files_name):
    # load_patch returns the (image, mask) pair for one sample path.
    full_path = os.path.join(self.data_path, files_name)
    return load_patch(full_path)

  def getitem(self, index):
    return self.__load__(self.files[index])

  def __iter__(self):
    return self

  def __next__(self):
    if self.i >= len(self.files):
      raise StopIteration()
    img_arr, mask_arr = self.getitem(self.i)
    print('generate: ', self.i)
    self.i += 1
    return img_arr, mask_arr

  def __call__(self):
    self.i = 0
    return self

然后按如下所示应用tf.data.Dataset.from_generator并起作用。

# BUG FIX: the rewritten DataGen.__init__ takes (files, data_path), but the
# original passed three arguments (TypeError at construction) and referenced
# the undefined name train_img_path.  load_patch() yields both the image and
# its mask from a single per-sample path, so only the image directory is
# passed.  NOTE(review): confirm load_patch resolves the matching mask from
# this path.
train_gen = DataGen(train_data, input_img_path)
valid_gen = DataGen(valid_data, input_img_path)

# from_generator re-invokes the callable each epoch; dtypes/shapes describe
# the (512, 512, 1) arrays the loader is expected to produce.
train = tf.data.Dataset.from_generator(train_gen, (tf.float64,tf.float64),(tf.TensorShape([512,512,1]), tf.TensorShape([512,512,1])))
valid = tf.data.Dataset.from_generator(valid_gen, (tf.float64,tf.float64),(tf.TensorShape([512,512,1]), tf.TensorShape([512,512,1])))

# Same pipeline shape as the in-memory version: cache, shuffle, batch, repeat.
train_dataset = train.cache().shuffle(len(train_data)).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
valid_dataset = valid.batch(batch_size).repeat()

history = model.fit(train_dataset, 
                   steps_per_epoch=int(np.ceil(len(train_data) / float(batch_size))),
                   epochs=epochs,
                   validation_data=valid_dataset,
                   validation_steps=int(np.ceil(len(valid_data) / float(batch_size))),
                   callbacks=[cp,lrs])

0 个答案:

没有答案