这是输入管道的代码。它将图像的大小调整为输入(224,224,3)和输出(224,224,2)。
image_path_list = glob.glob('/content/drive/My Drive/datasets/imagenette/*')
data = tf.data.Dataset.list_files(image_path_list)
def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image
def preprocess(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
L = image[:,:,0]/100.
ab = image[:,:,1:]/128.
input = tf.stack([L,L,L], axis=2)
return input, ab
train_ds = data.map(preprocess, tf.data.experimental.AUTOTUNE).batch(64).repeat()
train_ds = data.prefetch(tf.data.experimental.AUTOTUNE)
以下是该模型的代码。 我认为该模型没有任何问题,因为当我在图像上调用model.predict()时,它可以工作。 因此,我假设输入管道出了点问题,但自从我第一次使用tf.data以来,我一直无法弄清楚到底是什么。
vggmodel = tf.keras.applications.VGG16(include_top=False, weights='imagenet')
model = tf.keras.Sequential()
for i,layer in enumerate(vggmodel.layers):
model.add(layer)
for layer in model.layers:
layer.trainable=False
model.add(tf.keras.layers.Conv2D(256, (3,3), padding='same', activation='relu'))
model.add(tf.keras.layers.UpSampling2D((2,2)))
model.add(tf.keras.layers.Conv2D(128, (3,3), padding='same', activation='relu'))
model.add(tf.keras.layers.UpSampling2D((2,2)))
model.add(tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'))
model.add(tf.keras.layers.UpSampling2D((2,2)))
model.add(tf.keras.layers.Conv2D(16, (3,3), padding='same', activation='relu'))
model.add(tf.keras.layers.UpSampling2D((2,2)))
model.add(tf.keras.layers.Conv2D(8, (3,3), padding='same', activation='relu'))
model.add(tf.keras.layers.Conv2D(2, (3,3), padding='same', activation='tanh'))
model.add(tf.keras.layers.UpSampling2D((2,2)))
无论如何我打印(train_ds)都会得到:
<PrefetchDataset shapes: (), types: tf.string>
我尝试了以下代码:
path = next(iter(train_ds))
L,ab = preprocess(path)
L.shape
我知道了
TensorShape([224, 224, 3])
这意味着它正在返回3维张量。 那为什么在我打电话时出现错误?
model.fit(train_ds, epochs=1, steps_per_epoch=steps, callbacks=[model_checkpoint_callback, early_stopping_callback])
答案 0 :(得分:0)
layer.trainable = False和model.fit相反。 前者说只将模型设置为推理并关闭反向传播,而model.fit用于训练。 也许您在寻找model.predict?
答案 1 :(得分:0)
是的,花了一些时间,但我知道了。这是一个非常愚蠢的错误。
train_ds = data.map(preprocess, tf.data.experimental.AUTOTUNE).batch(64).repeat()
train_ds = data.prefetch(tf.data.experimental.AUTOTUNE)
实际上应该是:
train_ds = data.map(preprocess, tf.data.experimental.AUTOTUNE).batch(64).repeat().prefetch(tf.data.experimental.AUTOTUNE)