我正在尝试对 Cats and Dogs 数据集进行迁移学习,但问题是我已经阅读了许多文档,但我的损失并没有减少......有人可以帮忙吗?有一些根本性的错误。
我的损失保持在 50%,这是随机猜测。该模型根本就不是训练。我什至尝试更改 base_model.trainiable = True,仍然没有。
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
for filename in filenames:
print(os.path.join(dirname, filename))
import random
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import zipfile
import shutil
def extract_files(source, target):
unzip_source = zipfile.ZipFile(source, 'r')
unzip_source.extractall(target)
unzip_source.close()
os.listdir('/kaggle/input/')
extract_files('/kaggle/input/dogs-vs-cats/test1.zip', '/kaggle/working/')
extract_files('/kaggle/input/dogs-vs-cats/train.zip', '/kaggle/working/')
base_dir = ('/kaggle/working/')
base_dir
os.listdir('/kaggle/working/train')[:10]
print(len(os.listdir('/kaggle/working/train')))
os.makedirs('/kaggle/working/training/dogs')
os.makedirs('/kaggle/working/training/cats')
target_path = '/kaggle/working/training/'
os.listdir(target_path)
def copy_file(file_path, file_name, target):
source_path = os.path.join(file_path, file_name)
final_path = os.path.join(target_path, target)
shutil.copy(source_path, final_path)
def separate_files(file_path):
for img in os.listdir(file_path):
if img.split('.')[0] == 'cat':
copy_file(file_path, img, os.path.join(target_path, 'cats'))
else:
copy_file(file_path, img, os.path.join(target_path, 'dogs'))
separate_files(os.path.join(base_dir, 'train'))
os.listdir(os.path.join(base_dir,'training','cats'))[:20]
print(len(os.listdir(os.path.join(base_dir,'training','cats'))))
print(len(os.listdir(os.path.join(base_dir,'training','dogs'))))
cat_folder = os.path.join(os.path.join(base_dir,'training','cats'))
dog_folder = os.path.join(os.path.join(base_dir,'training','dogs'))
data_dir = os.path.join(base_dir, 'training')
plt.figure(figsize=(10,10))
for i in range(15):
sub_dir_name = np.random.choice(os.listdir(data_dir))
sub_dir = os.path.join(data_dir, sub_dir_name)
img_file = os.path.join(sub_dir, np.random.choice([x for x in os.listdir(sub_dir)] ))
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
img = plt.imread(img_file)
plt.imshow(img)
plt.xlabel(f'dir: {sub_dir_name}')
plt.show()
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_gen = ImageDataGenerator(rescale = 1.0/255.0,
shear_range= 0.2,
zoom_range= 0.3,
width_shift_range = 0.2,
height_shift_range = 0.2,
horizontal_flip= True,
vertical_flip= True,
rotation_range= 45,
validation_split= 0.2,
)
valid_gen = ImageDataGenerator(rescale = 1.0/255.0,
shear_range= 0.2,
zoom_range= 0.3,
horizontal_flip= True,
vertical_flip= True,
rotation_range= 45,
validation_split= 0.2,
)
train_ds = train_gen.flow_from_directory(data_dir,
target_size= (160,160),
shuffle= True,
batch_size = 128,
seed= 7,
subset= 'training'
)
valid_ds = valid_gen.flow_from_directory(data_dir,
target_size= (160,160),
shuffle= True,
batch_size = 128,
seed= 7,
subset= 'validation'
)
# from keras.applications.inception_resnet_v2 import InceptionResNetV2
# from keras.applications.inception_resnet_v2 import preprocess_input
# from keras.applications.inception_resnet_v2 import decode_predictions
# from sklearn.metrics import accuracy_score
tf.keras.backend.clear_session()
base_model = tf.keras.applications.MobileNetV2(input_shape = (160,160,3),
weights= 'imagenet',
include_top = False)
base_model.trainable = False
base_model.summary()
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
model = tf.keras.Sequential([base_model,
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation = 'relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1, activation = 'sigmoid')])
model.summary()
model.compile(optimizer= tf.keras.optimizers.Adam(learning_rate= 0.0001),
loss = tf.keras.losses.BinaryCrossentropy(from_logits= True),
metrics = ['accuracy'])
early_stop = tf.keras.callbacks.EarlyStopping(patience = 10)
reduce_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(patience = 5)
history = model.fit(train_ds, epochs=200, validation_data= valid_ds, callbacks= [early_stop, reduce_on_plateau])