使用大型数据集进行TensorFlow训练花费的时间太长

时间:2020-01-23 06:09:19

标签: tensorflow tensorflow2.0

昨天,我创建了一个带有自定义头部的预训练VGG19,并试图用60000张图像对其进行训练。 12个多小时后,第一个纪元的训练未完成。 批处理大小已设置为64,每个纪元的步数已设置为training_set_size / batch_size。

下面是DataLoader的代码:

    IMAGE_CHANNEL = 3

def crop(image, margin):
    return image[margin:-margin, margin:-margin]


def random_rotation(image, angle):
    M = cv2.getRotationMatrix2D((0, 0),angle,1)
    rows,cols, _ = image.shape
    new_img = cv2.warpAffine(image, M, (cols, rows))
    return new_img


def get_generator(in_gen, should_augment=True):
    weights = None
    if should_augment:
        image_gen = tf.keras.preprocessing.image.ImageDataGenerator(fill_mode='reflect',
                                       data_format='channels_last',
                                       brightness_range=[0.5, 1.5])
    else:
        image_gen = tf.keras.preprocessing.image.ImageDataGenerator(fill_mode='reflect',
                                       data_format='channels_last',
                                       brightness_range=[1, 1])
    for items in in_gen:
        in_x, in_y = items
        g_x = image_gen.flow(255 * in_x, in_y, batch_size=in_x.shape[0])
        x, y = next(g_x)
        yield x / 255.0, y


class DataLoader:
    def __init__(self, source_filename, dataset_path, image_size, batch_size, training_set_size=0.8, sample_size=None):
        path_dataset = Path(dataset_path)
        path_image_folders = path_dataset / 'images'
        self.data = pd.read_pickle(source_filename)
        if sample_size is not None:
            self.data = self.data[:sample_size]
        self.image_size = image_size
        self.batch_size = batch_size
        self.training_set_size = training_set_size
        self.steps_per_epoch = int(self.data.shape[0] * training_set_size // batch_size)
        if self.steps_per_epoch == 0: self.steps_per_epoch = 1
        self.validation_steps = int(self.data.shape[0] * (1 - training_set_size)//batch_size)
        if self.validation_steps == 0: self.validation_steps = 1

    def draw_idx(self, i):
        img_path = self.data.iloc[i].image
        img = tf.keras.preprocessing.image.img_to_array(tf.keras.preprocessing.image.load_img(str(img_path)))
        # print(img.shape)
        height, width, _ = img.shape
        fig = plt.figure(figsize=(15, 15), facecolor='w')
        # original image
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(img / 255.0)
        openness = self.data.iloc[i].Openness
        conscientiousness = self.data.iloc[i].Conscientiousness
        extraversion = self.data.iloc[i].Extraversion
        agreeableness = self.data.iloc[i].Agreeableness
        neuroticism = self.data.iloc[i].Neuroticism
        ax.title.set_text(
            f'O: {openness}, C: {conscientiousness}, E: {extraversion}, A: {agreeableness}, N: {neuroticism}')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def get_image(self, index, data, should_augment):
        # Read image and appropiate landmarks
        image = cv2.imread(data['image'].values[index])
        h, w, _ = image.shape
        o, c, e, a, n = data[['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Neuroticism']].values[
            index]

        should_flip = random.randint(0, 1)
        should_rotate = random.randint(0, 1)
        should_crop = random.randint(0, 1)
        if should_augment:
            if should_flip == 1:
                # print("Image {} flipped".format(data['path'].values[index]))
                image = cv2.flip(image, 1)
            if should_rotate == 1:
                angle = random.randint(-5, 5)
                image = random_rotation(image, angle)
            if should_crop == 1:
                margin = random.randint(1, 10)
                image = crop(image, margin)
        image = cv2.resize(image, (self.image_size, self.image_size))
        return [image, o, c, e, a, n]

    def generator(self, data, should_augment=True):
        while True:
            # Randomize the indices to make an array
            indices_arr = np.random.permutation(data.count()[0])
            for batch in range(0, len(indices_arr), self.batch_size):
                # slice out the current batch according to batch-size
                current_batch = indices_arr[batch:(batch + self.batch_size)]
                # initializing the arrays, x_train and y_train
                x_train = np.empty(
                    [0, self.image_size, self.image_size, IMAGE_CHANNEL], dtype=np.float32)
                y_train = np.empty([0, 5], dtype=np.int32)
                for i in current_batch:
                    # get an image and its corresponding color for an traffic light
                    [image, o, c, e, a, n] = self.get_image(i, data, should_augment)
                    # Appending them to existing batch
                    x_train = np.append(x_train, [image], axis=0)
                    y_train = np.append(y_train, [[o, c, e, a, n]], axis=0)
                    # replace nan values with zeros
                    y_train = np.nan_to_num(y_train)

                yield (x_train, y_train)

    def get_training_and_test_generators(self, should_augment_training=True, should_augment_test=True):
        msk = np.random.rand(len(self.data)) < self.training_set_size
        train = self.data[msk]
        test = self.data[~msk]
        train_gen = self.generator(train, should_augment_training)
        test_gen = self.generator(test, should_augment_test)
        return get_generator(train_gen, should_augment_training), get_generator(test_gen, should_augment_test)


    def show_batch_images_sample(self, images, landmarks, n_rows=3, n_cols=3):
        assert n_rows * n_cols <= self.batch_size, "Number of expected images to display is larger than batch!"
        fig = plt.figure(figsize=(15, 15))
        xs, ys = [], []
        count = 1
        for img, y in zip(images, landmarks):
            ax = fig.add_subplot(n_rows, n_cols, count)
            ax.imshow(img)
            h, w, _ = img.shape
            o, c, e, a, n = y
            ax.title.set_text(f'{o}, {c}, {e}, {a}, {n}')
            ax.axis('off')
            if count == n_rows * n_cols:
                break
            count += 1


class CallbackTensorboardImageOutput(Callback):
    def __init__(self, model, generator, log_dir, feed_inputs_display=9):
        # assert ((feed_inputs_display & (feed_inputs_display - 1)) == 0) and feed_inputs_display != 0
        self.generator = generator
        self.model = model
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)
        self.feed_inputs_display = feed_inputs_display
        self.seen = 0

    def plot_to_image(figure):
        """Converts the matplotlib plot specified by 'figure' to a PNG image and
        returns it. The supplied figure is closed and inaccessible after this call."""
        # Save the plot to a PNG in memory.
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        # Closing the figure prevents it from being displayed directly inside
        # the notebook.
        plt.close(figure)
        buf.seek(0)
        # Convert PNG buffer to TF image
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Add the batch dimension
        image = tf.expand_dims(image, 0)
        return image

    @staticmethod
    def get_loss(gt, predictions):
        return tf.losses.mse(gt, predictions)

    def on_epoch_end(self, epoch, logs={}):
        self.seen += 1
        if self.seen % 1 == 0:
            items = next(self.generator)
            images_to_display = self.feed_inputs_display
            images_per_cell_count = int(math.sqrt(images_to_display))
            # in case of regular model training using generator, an array is passed
            if not isinstance(items, dict):
                frames_arr, ocean_scores = items
                # Take just 1st sample from batch
                batch_size = frames_arr.shape[0]

                if images_to_display > batch_size:
                    images_to_display = batch_size
                frames_arr = frames_arr[0:images_to_display]
                ocean_scores = ocean_scores[0:images_to_display]

                y_pred = self.model.predict(frames_arr)
            # in case of adversarial training, a dictionary is passed
            else:
                batch_size = items['feature'].shape[0]
                if images_to_display > batch_size:
                    images_to_display = batch_size
                # items['feature'] = items['feature'][0:images_to_display]
                # landmarks = items['label'][0:images_to_display]
                frames_arr = items['feature']
                landmarks = items['label']
                y_pred = self.model.predict(items)

            figure = plt.figure(figsize=(15, 15))
            for i in range(images_to_display):
                image_current = frames_arr[i]
                y_prediction_current = y_pred[i]
                y_gt_current = ocean_scores[i]
                lbl_prediction = 'plot/img/{}'.format(i)
                ax = plt.subplot(images_per_cell_count, images_per_cell_count, i + 1, title=lbl_prediction)
                ax.imshow(image_current)
                ax.axis('off')

            with self.writer.as_default():
                tf.summary.image("Training Data", CallbackTensorboardImageOutput.plot_to_image(figure), step=self.seen)

下面是网络体系结构的定义以及fit_generator函数的调用:

data_loader = dataloader.DataLoader('dataset.pkl', '/home/niko/data/PsychoFlickr', 224, 64)
train_gen, test_gen = data_loader.get_training_and_test_generators()

pre_trained_model = tf.keras.applications.VGG19(input_shape=(data_loader.image_size, data_loader.image_size, dataloader.IMAGE_CHANNEL), weights='imagenet', include_top=False)
x = pre_trained_model.output
x = tf.keras.layers.Flatten()(x)

# Add a fully connected layer with 256 hidden units and ReLU activation
x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dropout(rate=0.5)(x)

x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dropout(rate=0.5)(x)
x = tf.keras.layers.Dense(5, name='regresion_output')(x)
x = tf.keras.layers.Activation('linear')(x)

model = tf.keras.Model(pre_trained_model.input, x)
print(model.summary())
log_dir = "logs/{}".format(model_name)
model_filename = "saved-models/{}.h5".format(model_name)
cb_tensorboard = TensorBoard(log_dir=log_dir)
callback_save_images = dataloader.CallbackTensorboardImageOutput(model, test_gen, log_dir)
checkpoint = ModelCheckpoint(model_filename, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
lr = 1e-3
opt = tf.optimizers.Adam(lr=lr)
model.compile(loss=loss_mse, optimizer=opt, metrics=[loss_mse])
history = model.fit_generator(
                train_gen,
                validation_data=test_gen,
                steps_per_epoch=data_loader.steps_per_epoch,
                epochs=20,
                validation_steps=data_loader.validation_steps,
                verbose=2,
                use_multiprocessing=True,
                callbacks=[checkpoint, callback_save_images, cb_tensorboard]
            )

当我尝试使用少量样本数据(200条记录)运行相同的过程时,一切似乎都正常。但是,在6万条记录的数据集上,经过12个多小时后,第1个纪元的训练尚未完成。

培训是在NVIDIA RTX2080Ti上进行的。

如果有人建议必须修改或进行一般配置以在合理的时间训练网络,我将不胜感激。

0 个答案:

没有答案
相关问题