Keras-基于用户输入的EarlyStopping

时间:2018-12-24 16:32:31

标签: python-3.x tensorflow keras

我想知道是否有一种简单的方法可以创建一种基于用户输入而不是监视任何特定指标来触发Keras提前停止的方法。

即,我想向执行训练的进程发送键盘信号,以使其脱离fit_generator函数并执行其余代码。

有什么想法吗?

编辑:基于@AnkurGoel的答案,我编写了以下代码:

# Monitors the SIGINT (ctrl + C) to safely stop training when it is sent
flag = False
class TerminateOnFlag(Callback):
    """Callback that terminates training when the flag is raised.
    """
    def on_batch_end(self, batch, logs=None):
        if flag:    
            self.model.stop_training = True

def handler(signum, frame):
    logging.info('SIGINT signal received. Training will finish after this epoch')
    global flag
    flag = True

signal.signal(signal.SIGINT, handler) # We assign a specific handler for the SIGINT signal
terminateOnFlag = TerminateOnFlag()
callbacks.append(terminateOnFlag)

callbacks是我输入到fit_generator的回调的列表。

在训练过程中,当我发送SIGINT信号时确实收到了消息SIGINT signal received. Training will finish after this epoch,但是当纪元结束时,什么也没有发生。发生了什么事?

2 个答案:

答案 0 :(得分:3)

您可以考虑以下方法:

使用一个全局变量,初始化0 使用信号处理程序,

当python进程接收到信号(中断)时,其值从0更改为1。

在Keras中使用自定义回调,当此变量值更改时停止训练

class TerminateOnFlag(Callback):
"""Callback that terminates training when flag=1 is encountered.
"""

def on_batch_end(self, batch, logs=None):
    if flag==1:    
        self.model.stop_training = True

可通过以下方式获得常规回调: https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L251

您仍然必须检查是否可以向fit_generator提供自定义回调,而不是标准回调。

这是信号处理程序的代码:

对于风行者:

import signal, os

def handler(signum, frame):
    print('Signal handler called with signal', signum)
    raise OSError("Couldn't open device!")

signal.signal(signal.CTRL_C_EVENT, handler) # only in python version 3.2

对于Linux:

import signal, os

def handler(signum, frame):
    print('Signal handler called with signal', signum)
    raise OSError("Couldn't open device!")

signal.signal(signal.SIGINT, handler) 

答案 1 :(得分:1)

更好和更安全的方法是将鼠标用作输入,停止和其他内部交互。

例如,当鼠标移到左侧(mouse_x <10)时,要在批处理结束时停止角豆:

def queryMousePosition():
    from ctypes import windll, Structure, c_long, byref
    class POINT(Structure): _fields_ = [("x", c_long), ("y", c_long)]
    pt = POINT()
    windll.user32.GetCursorPos(byref(pt))
    return pt.x, pt.y  # %timeit queryMousePosition()


class TerminateOnFlag(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        mouse_x, mouse_y = queryMousePosition()
        if mouse_x < 10:
            self.model.stop_training = True

callbacks=[keras.callbacks.ReduceLROnPlateau(), TerminateOnFlag()]

model.fit_generator(..., callbacks=callbacks, ...)