我想从另一个回调(如EarlyStopping或ModelCheckpoint)中的回调中使用自定义指标。但是我需要以某种方式保存/存储/记录该自定义指标,以便其他回调可以访问此指标?
我有:
class Metrics(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.precision = []
self.f1s = []
self.prc=0
self.f1s=0
def on_epoch_end(self, epoch, logs={}):
score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))
predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))
targ = self.validation_data[2]
predict = (predict < 0.5).astype(np.float)
self.prc=sklm.precision_score(targ, predict)
self.f1s=sklm.f1_score(targ, predict)
self.precision.append(prc)
self.f1s.append(f1s)
print("— val_f1: %f — val_precision: %f" %(self.f1s, self.prc))
return
现在
metrics = Metrics()
es = EarlyStopping(monitor=metrics.prc, mode='max', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)
model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
batch_size=512,
epochs=20,callbacks=[metrics,es],
validation_data=([dev_sen1, dev_sen2], dev_labels))
不起作用,因为Earlystopping不了解自定义精度指标?
有人知道这个回调的日志语句吗?我可以在那里保存指标吗?
答案 0 :(得分:0)
要了解这里到底发生了什么,您必须检查github上EarlyStopping和ModelCheckpoint类的源代码。您可以找到它here。
您的代码中的问题是您没有更新“ on_epoch_end”函数中的“ logs”字典。在该词典中,EarlyStopping和ModelCheckpoint类将查找您定义为“监视器”的内容。
因此,在您的情况下,如果要使用精度分数作为监视器,则代码应如下所示:
class Metrics(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.precision = []
self.f1scores = []
self.prc=0
self.f1s=0
def on_epoch_end(self, epoch, logs={}):
score = np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]]))
predict = np.round(np.asarray(self.model.predict([self.validation_data[0],self.validation_data[1]])))
targ = self.validation_data[2]
predict = (predict < 0.5).astype(np.float)
self.prc=sklm.precision_score(targ, predict)
self.f1s=sklm.f1_score(targ, predict)
self.precision.append(prc)
self.f1scores.append(f1s)
#Here is where I update the logs dictionary:
logs["prc"]=self.prc
logs["f1s"]=self.f1s
print("— val_f1: %f — val_precision: %f" %(self.f1s, self.prc))
然后您可以在CheckpointModel和EarlyStopping中调用这些自定义指标。但是,请确保将这些回调以正确的顺序放在fit_generator中:应首先放置指标,否则运行EarlyStopping时不会更新日志。
metrics = Metrics()
es = EarlyStopping(monitor="prc", mode='max', verbose=1, patience=3,min_delta=0.01,restore_best_weights=True)
model.compile(loss=contrastive_loss, optimizer=adam)
model.fit([train_sen1, train_sen2], train_labels,
batch_size=512,
epochs=20,callbacks=[metrics,es],
validation_data=([dev_sen1, dev_sen2], dev_labels))
答案 1 :(得分:0)
zealous_nightingale 的回答适用于 <NestedScrollView
xmlns:android="http://schemas.android.com/apk/res/android"
android:id="@+id/mainScrollView"
android:fillViewport="true"
android:layout_width="match_parent"
android:layout_height="match_parent">
<HorizontalScrollView
android:id="@+id/horizontalScroll"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</NestedScrollView>
回调,但是,对于 EarlyStopping
回调,您可能还需要将 ModelCheckpoint
标志设置为 _supports_tf_logs
以便更新 False
字典传递给回调:
log