I have the following TF 2.3.0 code. The model performs image super-resolution. As input I have two training datasets, plus two more for validation, and I want to pass them to model.fit.
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Add, Activation
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
import pathlib
def build_model():
    input_img = Input(shape=(48, 48, 1))
    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(input_img)
    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(model)
    model = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', activation='relu')(model)
    model = Conv2D(1, (3, 3), padding='same', kernel_initializer='he_normal')(model)
    res_img = model
    output_img = Add()([res_img, input_img])
    model = Model(inputs=input_img, outputs=output_img)
    return model
def load_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, channels=1)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image -= 0.5
    image /= 0.5
    return image
def configure_for_performance(ds, batch_size):
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds
if __name__ == '__main__':
    config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.8))
    config.gpu_options.allow_growth = True
    session = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(session)

    print('start training...')
    BATCH_SIZE = 64
    model = build_model()
    adam = optimizers.Adam(lr=1e-2)
    model.compile(adam, loss='mse')
    model.summary()
    print('start training....')

    data_orig = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\train\\orig\\*.png'), shuffle=False)
    data_pred = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\train\\pred\\*.png'), shuffle=False)
    valid_orig = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\valid\\orig\\*.png'), shuffle=False)
    valid_pred = tf.data.Dataset.list_files(str('C:\\SRColor2\\data\\div2k\\valid\\pred\\*.png'), shuffle=False)

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
    data_orig = data_orig.map(load_image, num_parallel_calls=AUTOTUNE)
    data_pred = data_pred.map(load_image, num_parallel_calls=AUTOTUNE)
    valid_orig = valid_orig.map(load_image, num_parallel_calls=AUTOTUNE)
    valid_pred = valid_pred.map(load_image, num_parallel_calls=AUTOTUNE)

    data_orig = configure_for_performance(data_orig, 64)
    data_pred = configure_for_performance(data_pred, 64)
    valid_orig = configure_for_performance(valid_orig, 64)
    valid_pred = configure_for_performance(valid_pred, 64)

    model.fit((data_pred, data_orig),
              epochs=40,
              batch_size=64,
              validation_data=(valid_pred, valid_orig))
    print('end training')
    print('training ended')
When I run the code, I get the following error:

ValueError: `y` argument is not supported when using dataset as input
Answer (score: 0)
The error message says it all. You can try zipping the datasets:
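A minimal sketch of that idea, adapted from the code in the question. The adjustments are mine: the datasets are zipped element-wise before configure_for_performance is applied (so shuffling cannot break the input/target pairing), configure_for_performance is called once per zipped dataset instead of once per branch, and batch_size is dropped from model.fit because the dataset is already batched.

# Pair each input image with its target by zipping the mapped (unbatched) datasets.
train_ds = tf.data.Dataset.zip((data_pred, data_orig))
valid_ds = tf.data.Dataset.zip((valid_pred, valid_orig))

# Cache/shuffle/batch/prefetch the paired datasets once each.
train_ds = configure_for_performance(train_ds, BATCH_SIZE)
valid_ds = configure_for_performance(valid_ds, BATCH_SIZE)

# The dataset now yields (x, y) batches, so no separate y or batch_size argument is needed.
model.fit(train_ds,
          epochs=40,
          validation_data=valid_ds)

Zipping before shuffling matters here: in the original code each dataset is shuffled independently, so even a zipped result would pair unrelated images.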