我在Keras中建立了一个模型来检测猫的关键点。对于每个图像,我都有3个关键点以及三个对应的热图。我将3个热图堆叠在一起,得到具有3个通道的单个图像。 我的模型的输入大小为64,64,3,输出图像大小为64,64,3。
我为图像和热图创建了2个ImageDataGenerator,并将它们压缩在一起。 我有30个纪元,批量是32个。 拟合模型时,它不会脱离训练单元!
图像和热图生成器如下所示:
from sklearn.model_selection import train_test_split
x_train, x_test = train_test_split(dataset['cropped_imgs'],test_size=0.20)
y_train, y_test = train_test_split(dataset['cropped_heatmaps'],test_size=0.20)
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(featurewise_center=False,
featurewise_std_normalization=False,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.2,
validation_split=0.2,
)
img_train_generator = datagen.flow(tf.convert_to_tensor(x_train) ,
batch_size=32,
shuffle = True,
seed = 1,
subset='training'
)
img_validation_generator = datagen.flow(tf.convert_to_tensor(x_train) ,
batch_size=32,
shuffle = True,
seed = 1,
subset='validation'
)
img_test_generator = datagen.flow(tf.convert_to_tensor(x_test) ,
batch_size=128,
shuffle = True,
seed = 1,
)
heatmapgen = ImageDataGenerator(featurewise_center=False,
featurewise_std_normalization=False,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.2,
validation_split=0.2)
heatmaps_train_generator = heatmapgen.flow(tf.convert_to_tensor(y_train) ,
batch_size=32,
shuffle = True,
seed = 1,
subset='training'
)
heatmaps_validation_generator = heatmapgen.flow(tf.convert_to_tensor(y_train) ,
batch_size=32,
shuffle = True,
seed = 1,
subset='validation'
)
img_heatmaps_test_generator = heatmapgen.flow(tf.convert_to_tensor(y_test) ,
batch_size=32,
shuffle = True,
seed = 1,
)
模型拟合:
model.compile(loss='mse', optimizer = opt,
metrics=['accuracy'])
model.compile(loss='mse', optimizer = opt,
metrics=['accuracy'])
train_generator = zip(img_train_generator, heatmaps_train_generator)
history = model.fit((pair for pair in train_generator),
epochs=30,
validation_data = (img_validation_generator,heatmaps_validation_generator)
)
训练1小时后的唯一输出是
Epoch 1/30
66/Unknown - 2305s 35s/step - loss: 0.0455 - accuracy: 0.3345
我尝试使用TPU运行该模型,但这似乎不是性能问题。该数据集包含1700张图像,数量不多! 知道为什么它会卡在试管中吗?
我们非常感谢您的帮助。