具有Keras模型的PyQt5 GUI在单独的线程中运行,再次按下“运行”按钮时卡住

时间:2019-03-08 14:58:33

标签: matplotlib keras pyqt5 qthread

我有一个应用程序,该应用程序接收带有标签火车数据的.pickle文件,并应建立一个神经网络(使用Keras)。它应该训练数据,并使用matplotlib在画布上实时显示训练/验证错误,并使用QprogressBar显示进度。

我有一个自定义的回调,该回调将pyqtSignal发送到每个时期的主GUI,发送当前时期以及累积的训练和验证错误。然后,在主程序中有一个函数可以接收信号并触发更新方法。

一切正常,直到我按GUI窗口-然后应用程序卡住了(但是网络仍然在外壳中运行)。我猜想点击中断会触发某个循环,使整个程序冻结,但我无法弄清楚哪个。

我已经搜索了其他有关使用线程时PyQt5 GUI卡住的问题,但没有找到答案-here

我尝试为Qthread使用Qthread.start()而不是Qthread.run()-但在这种情况下,绘图根本不会更新。

我已经写了一个完整的示例来演示该问题(数据文件应为.pickle格式,并包含X的列表[X,y]-样本为numpy ndarray,而y-相应的标签为numpy ndarray,可以在https://www.kaggle.com/luciferadmin/heart-disease-uci-in-pickle-format中找到):

import sys
import os
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model
from PyQt5.QtWidgets import QApplication, QProgressBar, QWidget, QVBoxLayout, QPushButton, QLineEdit, QFileDialog
from PyQt5.QtCore import QThread, pyqtSignal
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtWidgets import (QSizePolicy)
matplotlib.use('Qt5Agg')


class Plot(FigureCanvas):
    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)

        self.axes = fig.add_subplot(111)

        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)

        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        self.axes.set_xticks(range(1, 100, 10))


class MultiPlot(Plot):
    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        self.compute_initial_figure()

    def compute_initial_figure(self):
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        if y_list is not None:
            self.axes.clear()

            graph_handles = []

            markers = ['b:', 'r']
            y_index = 0

            for y in y_list:

                x = range(1, len(y) + 1)
                label = plot_labels[y_index]

                new_plot, = self.axes.plot(x, y, markers[y_index], markersize=2, label=label)
                graph_handles.append(new_plot)
                y_index += 1

                self.axes.set_xticks(x, int(len(list(x))/10))
            self.axes.legend(handles=graph_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs={}):
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))

        self.signal.emit(epoch, [self.train_err, self.val_err])


def classification_model(data_input_path, on_epoch_end_signal):

    # ///////////////////// TEST /////////////////////
    if os.path.exists(data_input_path):
        plot_losses = TrainPlotCallback(on_epoch_end_signal)
        with open(data_input_path, 'rb') as pickle_in:
            data = pkl.load(pickle_in)
            X = data[0]
            y = data[1]

        input_size = X.shape[1]

        # MODEL CREATION
        # ///////////////////// INPUT LAYER /////////////////////
        inputs = Input(shape=(input_size,))
        # ///////////////////// INPUT LAYER /////////////////////
        # ///////////////////// HIDDEN LAYER /////////////////////
        x = Dense(10, activation='relu', kernel_initializer='normal')(inputs)   # THE FIRST LAYER
        # ///////////////////// HIDDEN LAYER /////////////////////
        # ///////////////////// OUTPUT LAYERS /////////////////////
        predictions = Dense(len(y[0]), activation='softmax')(x)  # the length of the output layer is as the length of the classes being predicted.
        # ///////////////////// OUTPUT LAYERS /////////////////////
        # MODEL CREATION

        # ///////////////////// MODEL DEFINITION /////////////////////
        model = Model(inputs=inputs, outputs=predictions)
        model.compile(optimizer='Adam',
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        # ///////////////////// MODEL DEFINITION /////////////////////

        # ///////////////////// MODEL TRAINING /////////////////////
        model.fit(X, y, validation_split=0.2, batch_size=100, epochs=100, callbacks=[plot_losses])
        # ///////////////////// MODEL TRAINING /////////////////////


class ModelThread(QThread):
    epoch_end_signal = pyqtSignal(int, list)  # signal that has epoch # as the first parameter, and a list that contains the error values for the train and validation.

    def __init__(self, data_input_path):
        QThread.__init__(self)
        self.data_input_path = data_input_path

    def __del__(self):
        self.wait()

    def run(self):
        classification_model(data_input_path=self.data_input_path,
                             on_epoch_end_signal=self.epoch_end_signal
                             )


class DashBoard(QWidget):
    def __init__(self):
        super().__init__()
        self.main_v_box = QVBoxLayout(self)

        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< STRINGS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.input_data_path_str = ''
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< STRINGS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< PROGRESS BAR >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.progress_bar = QProgressBar()
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< PROGRESS BAR >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< BUTTONS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.run_model_btn = QPushButton('Run')
        self.browse_train_data_file_path_btn = QPushButton('Browse')
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< BUTTONS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< MULTI PLOTS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< MULTI PLOTS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< LINE EDITS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.train_data_file_path_le = QLineEdit()
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< LINE EDITS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)

    def pack(self):
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        btn_index = self.sender()

        if btn_index == self.browse_train_data_file_path_btn:
            self.input_data_path_str = QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')[0]
            self.train_data_file_path_le.setText(self.input_data_path_str)
        elif btn_index == self.run_model_btn:
            model_thread = ModelThread(data_input_path=self.input_data_path_str)
            model_thread.epoch_end_signal.connect(self.update_ui_on_epoch_end)
            self.progress_bar.show()
            model_thread.run()
            self.progress_bar.hide()

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        if current_epoch_num < 100:
            self.progress_bar.setValue(current_epoch_num)
        else:
            self.progress_bar.setValue(100)
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error', plot_labels=['Train Accuracy', 'Validation Accuracy'], y_list=[error_lists[0], error_lists[1]])

    def run_model(self):
        if os.path.exists(self.train_data_file_path_str) and os.path.exists(self.output_data_path_str):
            train_thread = ModelThread(data_input_path='')
            train_thread.epoch_end_signal.connect(self.update_ui_on_epoch_end)
            # train_thread.start()
            self.progress_bar.show()
            train_thread.run()
            self.progress_bar.hide()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    main_menu = DashBoard()
    sys.exit(app.exec_())

1 个答案:

答案 0 :(得分:1)

您最初的错误是您不应该直接调用run而是启动,但是您的线程是一个局部变量,稍后会被删除。

代替创建自定义QThread更好的解决方案是创建一个驻留在另一个线程中的QObject并使用QTimer.singleShot调用该函数。

import os
import sys
from functools import partial
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model

from PyQt5 import QtCore, QtWidgets

import matplotlib
matplotlib.use('Qt5Agg')

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

class Plot(FigureCanvas):
    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)

        self.axes = fig.add_subplot(111)

        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        self.axes.set_xticks(range(1, 100, 10))


class MultiPlot(Plot):
    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        self.compute_initial_figure()

    def compute_initial_figure(self):
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        if y_list is not None:
            self.axes.clear()

            graph_handles = []

            markers = ['b:', 'r']
            y_index = 0

            for y in y_list:

                x = range(1, len(y) + 1)
                label = plot_labels[y_index]

                new_plot, = self.axes.plot(x, y, markers[y_index], markersize=2, label=label)
                graph_handles.append(new_plot)
                y_index += 1

                self.axes.set_xticks(x, int(len(list(x))/10))
            self.axes.legend(handles=graph_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs={}):
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))
        self.signal.emit(epoch, [self.train_err, self.val_err])

def classification_model(data_input_path, on_epoch_end_signal):

    # ///////////////////// TEST /////////////////////
    if os.path.exists(data_input_path):
        plot_losses = TrainPlotCallback(on_epoch_end_signal)
        with open(data_input_path, 'rb') as pickle_in:
            data = pkl.load(pickle_in)
            X = data[0]
            y = data[1]

        input_size = X.shape[1]

        # MODEL CREATION
        # ///////////////////// INPUT LAYER /////////////////////
        inputs = Input(shape=(input_size,))
        # ///////////////////// INPUT LAYER /////////////////////
        # ///////////////////// HIDDEN LAYER /////////////////////
        x = Dense(10, activation='relu', kernel_initializer='normal')(inputs)   # THE FIRST LAYER
        # ///////////////////// HIDDEN LAYER /////////////////////
        # ///////////////////// OUTPUT LAYERS /////////////////////
        predictions = Dense(len(y[0]), activation='softmax')(x)  # the length of the output layer is as the length of the classes being predicted.
        # ///////////////////// OUTPUT LAYERS /////////////////////
        # MODEL CREATION

        # ///////////////////// MODEL DEFINITION /////////////////////
        model = Model(inputs=inputs, outputs=predictions)
        model.compile(optimizer='Adam',
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        # ///////////////////// MODEL DEFINITION /////////////////////

        # ///////////////////// MODEL TRAINING /////////////////////
        model.fit(X, y, validation_split=0.2, batch_size=100, epochs=100, callbacks=[plot_losses])
        # ///////////////////// MODEL TRAINING /////////////////////


class Worker(QtCore.QObject):
    started = QtCore.pyqtSignal()
    finished = QtCore.pyqtSignal()
    epoch_end_signal = QtCore.pyqtSignal(int, list)  # signal that has epoch # as the first parameter, and a list that contains the error values for the train and validation.

    @QtCore.pyqtSlot(str)
    def start_task(self, input_path):
        self.started.emit()
        classification_model(data_input_path=input_path,
                             on_epoch_end_signal=self.epoch_end_signal)
        self.finished.emit()


class DashBoard(QtWidgets.QWidget):
    def __init__(self):
        super().__init__()
        self.main_v_box = QtWidgets.QVBoxLayout(self)
        self.input_data_path_str = ''
        self.progress_bar = QtWidgets.QProgressBar()
        self.run_model_btn = QtWidgets.QPushButton('Run')
        self.browse_train_data_file_path_btn = QtWidgets.QPushButton('Browse')
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        self.train_data_file_path_le = QtWidgets.QLineEdit()
        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        self.worker = Worker()
        thread = QtCore.QThread(self)
        thread.start()
        self.worker.moveToThread(thread)
        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)
        self.worker.epoch_end_signal.connect(self.update_ui_on_epoch_end)
        self.worker.started.connect(self.progress_bar.show)
        self.worker.finished.connect(self.progress_bar.hide)
        self.worker.started.connect(partial(self.run_model_btn.setEnabled, False))
        self.worker.finished.connect(partial(self.run_model_btn.setEnabled, True))

    def pack(self):
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        btn_index = self.sender()

        if btn_index == self.browse_train_data_file_path_btn:
            self.input_data_path_str, _ = QtWidgets.QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')
            self.train_data_file_path_le.setText(self.input_data_path_str)
        elif btn_index == self.run_model_btn:
            QtCore.QTimer.singleShot(0, partial(self.worker.start_task, self.input_data_path_str))

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        if current_epoch_num < 100:
            self.progress_bar.setValue(current_epoch_num)
        else:
            self.progress_bar.setValue(100)
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error', plot_labels=['Train Accuracy', 'Validation Accuracy'], y_list=[error_lists[0], error_lists[1]])


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    main_menu = DashBoard()
    sys.exit(app.exec_())