我最近开始在 Google Colab 笔记本中使用 Tensorflow 进行机器学习,致力于对食物图像进行分类的网络。
我的数据集正好包含 101,000 张图像和 101 个类别 - 每个类别 1000 张图像。 我在此之后开发的网络Tensorflow Blog
我开发的代码如下:
#image dimensions
batch_size = 32
img_height = 50
img_width = 50
#80% for training, 20% for validating
train_ds = image_dataset_from_directory(data_dir,
shuffle=True,
validation_split=0.2,
subset="training",
seed=123,
batch_size=batch_size,
image_size=(img_height, img_width)
)
val_ds = image_dataset_from_directory(data_dir,
shuffle=True,
validation_split=0.2,
subset="validation",
seed=123,
batch_size=batch_size,
image_size=(img_height, img_width)
)
#autotuning, configuring for performance
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
#data augmentation layer
data_augmentation = keras.Sequential(
[
layers.experimental.preprocessing.RandomFlip("horizontal",
input_shape=(img_height,
img_width,
3)),
layers.experimental.preprocessing.RandomRotation(0.1),
layers.experimental.preprocessing.RandomZoom(0.1),
]
)
#network definition
num_classes = 101
model = Sequential([
data_augmentation,
layers.experimental.preprocessing.Rescaling(1./255),
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(128, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(256, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Dropout(0.2),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
经过 500 个 epoch 的训练后,准确率似乎增长得非常缓慢:
epoch 100: 2525/2525 - 19s 8ms/step - loss: 2.8151 - accuracy: 0.3144 - val_loss: 3.1659 - val_accuracy: 0.2549
epoch 500: 2525/2525 - 21s 8ms/step - loss: 2.7349 - accuracy: 0.0333 - val_loss: 3.1260 - val_accuracy: 0.2712
我试过了:
到目前为止,上面的代码提供了最好的结果,但我仍然想知道,
这种行为是预期的吗?这是拥有如此大数据集的结果吗?或者我的代码中是否存在任何可能阻碍学习过程的缺陷?
答案 0 :(得分:1)
在你的损失函数中删除 from_logits=True