Keras early stopping: monitoring a custom metric on a multi-head model

Date: 2019-11-23 00:47:37

Tags: python machine-learning keras

I have a model with 4 heads. For early stopping, however, I only want to monitor a metric that combines two of those heads.

Unfortunately, I get this warning:

Early stopping conditioned on metric val_control_only_metric which is not available. Available metrics are: loss, steering_loss, throttle_loss, smoosh_loss, is_sim_loss, steering_control_only_metric, throttle_control_only_metric, smoosh_control_only_metric, is_sim_control_only_metric, val_loss, val_steering_loss, val_throttle_loss, val_smoosh_loss, val_is_sim_loss, val_steering_control_only_metric, val_throttle_control_only_metric, val_smoosh_control_only_metric, val_is_sim_control_only_metric

I want to monitor the output of control_only_metric when it runs on the validation set. Is this possible?
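The names in the warning hint at why the key is missing. For a multi-output model, Keras evaluates the loss and every metric separately for each output and logs them as <output>_loss and <output>_<metric>, adding a val_ prefix for the validation pass. This also means each call to loss or control_only_metric in the code below receives a single head's (batch, 1) tensors, so y_true[0] is the first sample of that head's batch rather than the steering target. The sketch below only rebuilds the naming pattern visible in the warning; per_output_log_keys is a hypothetical helper, not a Keras function.

```python
def per_output_log_keys(output_names, metric_names):
    """Hypothetical helper: rebuild the log keys listed in the warning.

    Mirrors the pattern seen there: one total loss, one loss per output,
    one entry per (output, metric) pair, then the same keys prefixed "val_".
    """
    train_keys = ["loss"]
    train_keys += [f"{out}_loss" for out in output_names]
    train_keys += [f"{out}_{m}" for out in output_names for m in metric_names]
    return train_keys + ["val_" + key for key in train_keys]


keys = per_output_log_keys(
    ["steering", "throttle", "smoosh", "is_sim"], ["control_only_metric"]
)
print(len(keys))                                   # 18
print("val_control_only_metric" in keys)           # False
print("val_steering_control_only_metric" in keys)  # True
```

No key in that list spans two heads, which is consistent with EarlyStopping reporting val_control_only_metric as unavailable.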

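import keras
from keras import backend as K
from keras.layers import Input, Convolution2D, Dropout, Flatten, Dense, Lambda
from keras.models import Model

def default_n_linear():
    input_shape = (120, 160, 3)
    roi_crop = (0, 0)
    drop = 0.1
    # adjust_input_shape() comes from the asker's project and is not shown here.
    input_shape = adjust_input_shape(input_shape, roi_crop)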

    img_in = Input(shape=input_shape, name='img_in')                                   # 0
    x = img_in
    x = Convolution2D(24, (5,5), strides=(2,2), activation='relu', name="conv2d_1")(x) # 1
    x = Dropout(drop)(x)                                                               # 2
    x = Convolution2D(32, (5,5), strides=(2,2), activation='relu', name="conv2d_2")(x) # 3
    x = Dropout(drop)(x)                                                               # 4
    x = Convolution2D(64, (5,5), strides=(2,2), activation='relu', name="conv2d_3")(x) # 5
    x = Dropout(drop)(x)                                                               # 6
    x = Convolution2D(64, (3,3), strides=(1,1), activation='relu', name="conv2d_4")(x) # 7
    x = Dropout(drop)(x)                                                               # 8
    x = Convolution2D(64, (3,3), strides=(1,1), activation='relu', name="conv2d_5")(x) # 9
    x = Dropout(drop)(x)                                                               # 10

    x = Flatten(name='flattened')(x)                                                   # 11
    x = Dense(100, activation='relu')(x)                                               # 12
    x = Dropout(drop)(x)                                                               # 13
    x = Dense(50, activation='relu')(x)                                                # 14
    x = Dropout(drop)(x)                                                               # 15

    steering   = Dense(1, activation="linear", name="steering")(x)
    throttle   = Dense(1, activation="linear", name="throttle")(x)
    smoosh     = Dense(1, activation="sigmoid", name="smoosh")(x)

    x_stop_grad = Lambda(lambda x: K.stop_gradient(x))(x)
    is_sim      = Dense(1, activation="sigmoid", name="is_sim")(x_stop_grad) # real=0, sim=1

    outputs = [steering, throttle, smoosh, is_sim]
    model = Model(inputs=[img_in], outputs=outputs)

    return model

def loss(y_true, y_pred):
    gt_angle      = y_true[0]
    gt_throttle   = y_true[1]
    gt_is_sim     = y_true[2]
    point_5       = y_true[3]

    pred_angle    = y_pred[0]
    pred_throttle = y_pred[1]
    pred_is_sim   = y_pred[2]
    pred_smoosh   = y_pred[3]

    loss  = K.mean(K.square(pred_angle-gt_angle))
    loss += K.mean(K.square(pred_throttle-gt_throttle))
    loss += K.mean(K.square(pred_is_sim-gt_is_sim))
    loss += K.mean(K.square(pred_smoosh - point_5))
    return loss

def control_only_metric(y_true, y_pred):

    gt_angle      = y_true[0]
    gt_throttle   = y_true[1]

    pred_angle    = y_pred[0]
    pred_throttle = y_pred[1]

    return K.mean(K.square(pred_angle-gt_angle)) +\
           K.mean(K.square(pred_throttle-gt_throttle))


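model = default_n_linear()
# The question passed self.optimizer (from the asker's class); 'adam' stands in here.
model.compile(optimizer='adam', loss=loss, metrics=[control_only_metric])

early_stop = keras.callbacks.EarlyStopping(monitor="val_control_only_metric",
                                           min_delta=0.05,
                                           patience=5,
                                           mode='min')

# train_gen, val_gen, steps_per_epoch, epochs, val_steps, workers_count and
# use_multiprocessing are defined elsewhere in the asker's training script.
model.fit_generator(train_gen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    validation_data=val_gen,
                    callbacks=[early_stop],
                    validation_steps=val_steps,
                    workers=workers_count,
                    use_multiprocessing=use_multiprocessing)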
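One possible workaround, sketched under assumptions: compute the combined value at the end of each epoch and write it into the logs dict before EarlyStopping reads it. This relies on the assumption that Keras hands the same logs dict to each callback in list order (which is how the Keras 2.x callback list behaves), so the combining callback has to be listed before early_stop. CombineMetrics is a hypothetical name, not part of Keras. If each head is trained with a plain mean squared error loss (for example compiling with loss='mse'), val_steering_loss + val_throttle_loss should match what control_only_metric was meant to compute, averaged over the validation set.

```python
try:
    from keras.callbacks import Callback
except ImportError:  # fallback so the logic can be tried without Keras installed
    Callback = object


class CombineMetrics(Callback):
    """Hypothetical callback: store the sum of several logged values under a new key."""

    def __init__(self, target_key, source_keys):
        super().__init__()
        self.target_key = target_key    # key to add, e.g. "val_control_only_metric"
        self.source_keys = source_keys  # existing log keys to add together

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        # Only add the combined key when every source value was logged this epoch.
        if all(key in logs for key in self.source_keys):
            logs[self.target_key] = sum(logs[key] for key in self.source_keys)


logs = {"val_steering_loss": 0.25, "val_throttle_loss": 0.5}
combine = CombineMetrics("val_control_only_metric",
                         ["val_steering_loss", "val_throttle_loss"])
combine.on_epoch_end(0, logs)
print(logs["val_control_only_metric"])  # 0.75
```

In the training code above, it would be passed as callbacks=[combine, early_stop], keeping monitor="val_control_only_metric" on the EarlyStopping callback.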

0 answers:

No answers yet.