Out of memory when training a multi-label classification model

Date: 2021-05-14 17:55:09

Tags: python tensorflow

I am a student who has been studying machine learning for a few months.

Goal

I am currently trying to build a multi-label image classification model.

Problem

As training progresses, memory usage keeps growing until an error occurs. Although I am only using about 500 MB of data (roughly 50,000 images in total), memory consumption exceeds 32 GB. I suspect there is a problem with the dataset I prepared.
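
For a sense of scale, here is a rough back-of-the-envelope estimate (my assumption: the decoded images end up held in memory as float32 tensors at the 224x224x3 size used in the parsing functions below):

# Rough estimate, assuming 50,000 decoded images as float32 tensors of shape 224x224x3
num_images = 50000
bytes_per_image = 224 * 224 * 3 * 4            # height * width * channels * sizeof(float32)
print(num_images * bytes_per_image / 1024**3)  # ~28 GB of raw pixels, far more than the ~500 MB of JPEGs on disk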

What I tried

I changed the batch size and also switched to an implementation that uses a data generator.

  • Batch size → training got through only about 1/10 of the total steps before the error. I don't think the run would finish even if I set the batch size to 1.

  • Data generator → I get the error: [Error occurred when finalizing GeneratorDataset iterator: Failed precondition: Python interpreter state is not initialized. The process may be terminated.]
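
(Below is only a minimal sketch of the kind of generator-based loading I mean, built on tf.keras.utils.Sequence. It is not my actual generator code; the filename/label lists are the same ones built in main() further down.)

# Minimal sketch (not my actual code): a Sequence that loads one batch of images
# from disk at a time instead of keeping everything in memory.
import math
import numpy as np
import tensorflow as tf

class ImageSequence(tf.keras.utils.Sequence):
    def __init__(self, filenames, labels, batch_size=8):
        self.filenames = filenames
        self.labels = np.asarray(labels, dtype=np.float32)
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return math.ceil(len(self.filenames) / self.batch_size)

    def __getitem__(self, idx):
        batch_files = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        images = []
        for path in batch_files:
            img = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
            img = tf.image.resize(img, [224, 224]) / 255.0
            images.append(img)
        return tf.stack(images), batch_labels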

Code (※ this is only part of it)

import ~~~

AUTOTUNE = tf.data.experimental.AUTOTUNE

def macro_f1(y, y_hat, thresh=0.5):
    y_pred = tf.cast(tf.greater(y_hat, thresh), tf.float32)
    tp = tf.cast(tf.math.count_nonzero(y_pred * y, axis=0), tf.float32)
    fp = tf.cast(tf.math.count_nonzero(y_pred * (1 - y), axis=0), tf.float32)
    fn = tf.cast(tf.math.count_nonzero((1 - y_pred) * y, axis=0), tf.float32)
    f1 = 2*tp / (2*tp + fn + fp + 1e-16)
    macro_f1 = tf.reduce_mean(f1)
    return macro_f1

def create_dataset(filenames, labels, is_training=True, total_cnt=0, batch_size=8, train_flag=False):
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    if train_flag:
        dataset = dataset.map(parse_function, num_parallel_calls=AUTOTUNE)
    else:
        dataset = dataset.map(parse_function_val, num_parallel_calls=AUTOTUNE)
    dataset = dataset.cache()
    if train_flag:
        dataset = dataset.shuffle(buffer_size=total_cnt)
    dataset = dataset.repeat().batch(batch_size)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image_resized = tf.image.resize(image_decoded, [224, 224])
    image_normalized = (image_resized / 255) - 1
    image_aug = tf.image.random_flip_left_right(image=image_normalized)
    image_aug = tf.image.random_flip_up_down(image=image_aug)
    return image_aug, label

def parse_function_val(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image_resized = tf.image.resize(image_decoded, [224, 224])
    image_normalized = (image_resized / 255) - 1
    return image_normalized, label


def main():

    ~Abbreviation~

    # list for dataset
    a_train    = []
    b_train    = []
    c_train    = []
    image_list_train = []

    a_val   = []
    b_val   = []
    c_val   = []
    asphalt_val    = []
    image_list_val = []

    DATA_DIR = os.path.join(DIR, 'data')
    modes = ['train', 'val']
    classes = ['a', 'b', 'c', 'a_b', 'b_c', 'c_a', 'a_b_c'] #dir-name

    #make dataset
    for mode in modes:
        for cls_name in classes:
            dir_path = os.path.join(DATA_DIR, mode, cls_name)
            file_list = glob.glob(os.path.join(dir_path, '*'))
            file_list = sorted(file_list)

            for file_name in file_list:
                # Split the folder name (ex: a_b) by _ and 
                # label it according to whether it contains the letters a, b, c
                cls_list = cls_name.split('_')
                if 'a' in cls_list: a = 1
                else:               a = 0
                if 'b' in cls_list: b = 1
                else:               b = 0
                if 'c' in cls_list: c = 1
                else:               c = 0

                if mode == 'train':
                    a_train.append(a)
                    b_train.append(b)
                    c_train.append(c)
                    image_list_train.append(file_name)
                elif mode == 'val':
                    a_val.append(a)
                    b_val.append(b)
                    c_val.append(c)
                    image_list_val.append(file_name)
                else:
                    print('ERROR')

    num_train = len(image_list_train)
    num_val = len(image_list_val)

    # make labels, e.g. [[0,0,1],[1,1,0],[0,1,0]]
    cls_id_train = []
    cls_id_val = []
    for i in range(num_train):
        cls_id_train.append([a_train[i], b_train[i], c_train[i]])
    for i in range(num_val):
        cls_id_val.append([a_val[i], b_val[i], c_val[i]])

    checkpoint_path = ckpt_dir + '/ckpt-{epoch}-{val_macro_f1:.2f}-{val_loss:.2f}'

    # make dataset
    train_ds = create_dataset(image_list_train,
                              cls_id_train,
                              total_cnt=num_train,
                              batch_size=batch_size,
                              train_flag=True,
                              )
    val_ds = create_dataset(image_list_val,
                            cls_id_val,
                            total_cnt=num_val,
                            batch_size=batch_size,
                            train_flag=False,
                            )

    # train
    IMG_SHAPE = (img_size, img_size, channels)
    base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                                include_top=False,
                                                weights='imagenet')
    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    prediction_layer = tf.keras.layers.Dense(n_classes, activation='sigmoid')
    model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])
    model.compile(optimizer=optimizers.SGD(lr=lr, momentum=momentum, nesterov=nesterov),
                  loss='binary_crossentropy',
                  metrics=[macro_f1])

    ckpt_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 monitor='val_macro_f1',
                                                 mode='max',
                                                 verbose=1)

    csv_logger = tf.keras.callbacks.CSVLogger(ckpt_dir+'/training.csv', separator=',')

    history = model.fit(train_ds,
                        steps_per_epoch= int(num_train//batch_size),
                        validation_data=val_ds,
                        validation_steps= int(num_val//batch_size),
                        shuffle=True,
                        epochs=epochs,
                        callbacks=[ckpt_cb, csv_logger],)

    model.save_weights(ckpt_dir + '/my_checkpoint')

1 Answer:

Answer 0: (score: 0)

Add this code snippet before the rest of your code. If you are using the Tensorflow-GPU version, this enables memory growth, so TensorFlow allocates GPU memory incrementally instead of reserving all of it up front.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
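
Note that memory growth has to be configured before the GPUs are initialized (i.e., before any op has run on them); calling set_memory_growth later raises a RuntimeError.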