Error when using model.fit with datasets as input to a super-resolution model

Time: 2020-08-24 19:13:50

Tags: python tensorflow

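I have the TF 2.3.0 code below. The model performs image super-resolution. As input I have two training datasets (pred and orig) and, separately, two validation datasets. I want to use model.fit with these datasets.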

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Add, Activation
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
import pathlib


def build_model():
    input_img = Input(shape=(48, 48, 1))
    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(input_img)

    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(model)
    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(model)
    model = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal')(model)
    res_img = model

    output_img = Add()([res_img, input_img])
    model = Model(inputs=input_img, outputs=output_img)
    return model


def load_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, channels=1)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image -= 0.5
    image /= 0.5

    return image

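def configure_for_performance(ds, batch_size):
    ds = ds.cache()
    # Shuffles this dataset on its own; when applied separately to the pred and
    # orig datasets, the two shuffles are independent and do not keep pairs aligned.
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTOTUNE)  # AUTOTUNE is the global set in __main__
    return ds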


if __name__ == '__main__':

    config = tf.compat.v1.ConfigProto(gpu_options= tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.8))
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(session)
    print('start training...')
    BATCH_SIZE = 64

    model = build_model()
    adam = optimizers.Adam(learning_rate=1e-2)
    model.compile(adam, loss='mse')
    model.summary()

    print('start training....')

    data_orig =  tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\train\\orig\\*.png'), shuffle=False)
    data_pred = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\train\\pred\\*.png'), shuffle=False)
    valid_orig = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\valid\\orig\\*.png'), shuffle=False)
    valid_pred = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\valid\\pred\\*.png'), shuffle=False)

    AUTOTUNE = tf.data.experimental.AUTOTUNE

    # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
    data_orig = data_orig.map(load_image, num_parallel_calls=AUTOTUNE)
    data_pred = data_pred.map(load_image, num_parallel_calls=AUTOTUNE)
    valid_orig = valid_orig.map(load_image, num_parallel_calls=AUTOTUNE)
    valid_pred = valid_pred.map(load_image, num_parallel_calls=AUTOTUNE)

    data_orig = configure_for_performance(data_orig, 64)
    data_pred = configure_for_performance(data_pred, 64)
    valid_orig = configure_for_performance(valid_orig, 64)
    valid_pred = configure_for_performance(valid_pred, 64)

    model.fit((data_pred, data_orig),
              epochs=40,
              batch_size=64,
              validation_data=(valid_pred, valid_orig))
    print('end training')

    print('training ended')

When I run the code, I get the following error: ValueError: `y` argument is not supported when using dataset as input.

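  1. What do I need to change to make the code run?
  2. In my case, I generated 48x48-pixel patches from each image. The images are high-resolution. If, for an image of resolution (X, Y), I want to generate (X // 48) x (Y // 48) patches of 48x48 pixels, can the dataset size grow dynamically?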
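A sketch for question 2, not taken from the thread: one option is to cut the tiles inside the input pipeline with `tf.image.extract_patches` and `Dataset.flat_map`, so each image yields (X // 48) x (Y // 48) elements as it is read, and the total count never has to be fixed in advance. The helper names `image_to_patches`, `patch_dataset` and `pair_to_patches`, and the constants `PATCH` and `CHANNELS`, are hypothetical; the sketch assumes non-overlapping grayscale tiles (1 channel, as `load_image` produces) and that each pred image has the same size as its orig image.

```python
import tensorflow as tf

# Assumptions (hypothetical helpers, not from the original post): tiles are
# non-overlapping, grayscale (1 channel, as in load_image), and 48x48 to match
# the model's Input(shape=(48, 48, 1)).
PATCH = 48
CHANNELS = 1


def image_to_patches(image):
    """Cut one [H, W, CHANNELS] image into (H // PATCH) * (W // PATCH) tiles.

    padding='VALID' drops border pixels that do not fill a whole tile.
    """
    patches = tf.image.extract_patches(
        images=image[tf.newaxis],        # add a batch axis: [1, H, W, C]
        sizes=[1, PATCH, PATCH, 1],
        strides=[1, PATCH, PATCH, 1],    # stride == size, so tiles do not overlap
        rates=[1, 1, 1, 1],
        padding='VALID')
    # [1, H // PATCH, W // PATCH, PATCH * PATCH * C] -> [num_tiles, PATCH, PATCH, C]
    return tf.reshape(patches, [-1, PATCH, PATCH, CHANNELS])


def patch_dataset(images_ds):
    """Turn a dataset of whole images into a dataset of tiles.

    flat_map emits a variable number of elements per image, so the number of
    examples grows as images are read instead of being fixed up front.
    """
    return images_ds.flat_map(
        lambda img: tf.data.Dataset.from_tensor_slices(image_to_patches(img)))


def pair_to_patches(pred, orig):
    """Tile a (pred, orig) pair identically so input and target tiles stay aligned."""
    return tf.data.Dataset.from_tensor_slices(
        (image_to_patches(pred), image_to_patches(orig)))
```

For training pairs, the tiling would be applied after zipping and before batching, e.g. `tf.data.Dataset.zip((data_pred, data_orig)).flat_map(pair_to_patches)`. Keras simply iterates such a dataset until it is exhausted in each epoch, so the varying number of tiles per image needs no extra bookkeeping.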

1 answer:

Answer 0 (score: 0)

The error message says it all. You can try zipping the datasets:
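A minimal sketch of that idea, assuming `pred_ds` and `orig_ds` are the per-image datasets after `.map(load_image)` (before `configure_for_performance`) and that both directories list corresponding files in the same order, for example identical file names. `make_pairs` is a hypothetical helper. Keras expects `x` to be a single dataset whose elements are `(inputs, targets)` tuples, so the targets cannot be supplied as a separate dataset, whether as `y` or inside `validation_data=(valid_pred, valid_orig)`, which Keras unpacks as `(x_val, y_val)`.

```python
import tensorflow as tf


def make_pairs(pred_ds, orig_ds, batch_size, shuffle_buffer=1000):
    """Combine an input dataset and a target dataset into one (x, y) dataset.

    make_pairs is a hypothetical helper; pred_ds and orig_ds are assumed to
    yield corresponding images in the same order.
    """
    # Zip first: every element becomes an (input, target) tuple, which is the
    # form Keras expects from a single dataset passed as x.
    ds = tf.data.Dataset.zip((pred_ds, orig_ds))
    # Cache and shuffle the zipped pairs, so each input stays with its own
    # target (shuffling the two datasets separately would mix them up).
    ds = ds.cache().shuffle(buffer_size=shuffle_buffer)
    ds = ds.batch(batch_size)
    return ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
```

With the datasets from the question, the call would then look like `model.fit(make_pairs(data_pred, data_orig, 64), epochs=40, validation_data=make_pairs(valid_pred, valid_orig, 64))`, without `batch_size`, because the dataset is already batched.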
