我正在交叉验证一个模型,分为5个部分。 然后我针对每个拆分,损失和val_loss按时期进行绘制。
我得到类似的东西:
我发现这幅情节令人不安。
我如何交叉验证:
def cv(X, y, model, n_splits=5, epochs=5, batch_size=1024,
random_state=42, verbose=0):
# kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
kf = KFold(n_splits=n_splits, shuffle=False, random_state=random_state)
histories = []
for s in kf.split(X):
X_train = X.iloc[s[0]].to_numpy()
y_train = y.iloc[s[0]]['Target'].to_numpy()
X_test = X.iloc[s[1]].to_numpy()
y_test = y.iloc[s[1]]['Target'].to_numpy()
h = model.fit(X_train, y_train,
epochs=epochs,
batch_size=batch_size,
validation_data=(X_test, y_test),
verbose=verbose)
histories.append(h)
return histories
模型:
def model_8(input_dim) -> tf.keras.models:
get_custom_objects().update({'swish': Activation(swish)})
inputs = Input(shape=(input_dim,))
x = Dense(200, activation='swish', name='hl_1')(inputs)
x = Dense(200, activation='swish', name='hl_2')(x)
x = Dense(200, activation='swish', name='hl_3')(x)
x = Dense(200, activation='swish', name='hl_4')(x)
x = Dense(200, activation='swish', name='hl_5')(x)
x = Dense(200, activation='swish', name='hl_6')(x)
x = Dense(200, activation='swish', name='hl_7')(x)
x = Dense(200, activation='swish', name='hl_8')(x)
x = Dense(200, activation='swish', name='hl_9')(x)
x = Dense(200, activation='swish', name='hl_10')(x)
x = Dense(200, activation='swish', name='hl_11')(x)
output = Dense(1, activation='sigmoid', name='output')(x)
model = Model(inputs=inputs, outputs=output)
model.compile(loss='mean_squared_error',
optimizer='adam')
#model._name = function.__name__
model._name = inspect.stack()[0][3]
return model
绘图功能:
def plots_(models_cv_histories, n_splits, save=False):
"""Plot all the learning curves for each trained models
Arguments:
models_cv_histories {array} -- array of histoires
"""
nb_models = len(models_cv_histories)
fig, axes = plt.subplots(nrows=nb_models,
ncols=n_splits, figsize=(12, 5))
row_index = 0
for cv_model in models_cv_histories:
hist = cv_model[1]
epochs = range(1, len(hist[0].epoch) + 1)
col_index = 0
for split_ in hist:
loss = split_.history['loss']
epochs = split_.epoch
val_loss = split_.history['val_loss']
model_name = split_.model.name
if nb_models > 1:
ax = axes[row_index][col_index]
else:
ax = axes[col_index]
ax.set_title(model_name + ' split ' + str(col_index))
ax.plot(epochs, loss, color="r", label="loss")
ax.plot(epochs, val_loss, color="g", label="val_loss")
ax.set_xlabel("epochs")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_ylabel("loss")
ax.legend(loc="upper right")
col_index += 1
row_index += 1
fig.subplots_adjust()
if save:
plt.savefig("plots/test.png")
fig.tight_layout()
plt.show()
您能给我一些提示吗?