How can I train a Keras model without freezing the GUI?

Asked: 2019-07-18 10:30:50

Tags: python tensorflow keras kivy

I am building a GUI with the Kivy framework for creating Keras models. The problem is that the GUI freezes while the model is training.

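Following this link (https://github.com/kivy/kivy/wiki/Working-with-Python-threads-inside-a-Kivy-application), I tried to use threads to solve the problem, but the code I wrote does not work as expected. Any suggestions would be appreciated.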

The .kv file is the same as the one at the link; here is my Python code:

import threading
import time

from kivy.app import App
from kivy.lang import Builder
from kivy.factory import Factory
from kivy.animation import Animation
from kivy.clock import Clock, mainthread
from kivy.uix.gridlayout import GridLayout

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
x_train, x_test = x_train.reshape(60000,28,28,1) / 255.0, x_test.reshape(10000,28,28,1) / 255.0
input_shape = (28, 28, 1)

model = Sequential([
    Flatten(input_shape=input_shape),
    Dense(10, activation='softmax')
])

class UpdateProgressBar(Callback):
    epoch = 0
    def __init__(self, th):
        super().__init__()  # let Keras initialise its own callback state
        self.th = th

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch
        self.th.update_label_text(str(epoch))

class RootWidget(GridLayout):

    stop = threading.Event()

    def start_second_thread(self, l_text):
        threading.Thread(target=self.second_thread, args=(l_text,), daemon=True).start()

    def second_thread(self, label_text):
        Clock.schedule_once(self.start_test, 0)

        self.stop_test()
        # Start a new thread with an infinite loop and stop the current one.
        threading.Thread(target=self.infinite_loop).start()

    def start_test(self, *args):
        global model
        model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
        model.fit(x_train, y_train, batch_size=64, epochs=5, verbose=1,
                  callbacks=[UpdateProgressBar(self)])

    @mainthread
    def update_label_text(self, new_text):
        self.lab_2.text = new_text

    @mainthread
    def stop_test(self):
        self.lab_1.text = ('Second thread exited, a new thread has started. '
                           'Close the app to exit the new thread and stop '
                           'the main process.')

        self.lab_2.text = str(int(self.lab_2.text) + 1)

    def infinite_loop(self):
        iteration = 0
        while True:
            if self.stop.is_set():
                # Stop running this thread so the main Python process can exit.
                return
            iteration += 1
            print('Infinite loop, iteration {}.'.format(iteration))
            time.sleep(1)

buildKV = Builder.load_file("thread.kv")

class ThreadedApp(App):
    def on_stop(self):
        self.root.stop.set()
    def build(self):
        return RootWidget()

if __name__ == '__main__':
    ThreadedApp().run()

1 answer:

Answer 0 (score: 0)

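Furas is right: there is no need to schedule the training with Clock, since a function scheduled that way runs on the main thread and blocks the GUI. The remaining problem is the TensorFlow graph. The graph should be saved beforehand as a global, and the thread function should look like this: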

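# Assumption (not shown in this answer): `graph` is a module-level variable
# captured right after the model is built, e.g. graph = tf.get_default_graph()
# in TensorFlow 1.x.
def second_thread(self, label_text):
    # label_text is unused; it is kept to match the caller in the question.
    # Train directly in this worker thread, inside the saved graph.
    with graph.as_default():
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        model.fit(x_train, y_train,
                  batch_size=60,
                  epochs=5,
                  verbose=1,
                  callbacks=[UpdateProgressBar(self)])
    self.stop_test()
    threading.Thread(target=self.infinite_loop).start()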

I found the solution at the following link: https://github.com/keras-team/keras/issues/2397#issuecomment-254919212
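
To illustrate the threading pattern without Kivy or TensorFlow, here is a minimal standard-library sketch. It is only an analogy: `main_queue`, `schedule_on_main`, `train` and `update_label` are hypothetical names standing in for the GUI event loop, Clock-style scheduling, `model.fit` and the label update respectively. The point it tries to show is that the heavy work is called directly in the worker thread, while only the small UI update is handed back to the main thread.

```python
import queue
import threading

# Hypothetical stand-in for a GUI event loop: callbacks placed on this queue
# are executed by the main thread only (roughly the role of Kivy's Clock).
main_queue = queue.Queue()
ran_on = {}  # records which thread ran each piece of work
main_thread_name = threading.current_thread().name


def schedule_on_main(func, *args):
    """Ask the main thread to call func(*args) later."""
    main_queue.put((func, args))


def update_label(text):
    """Placeholder for setting a widget's text."""
    ran_on["update_label"] = threading.current_thread().name
    ran_on["last_text"] = text


def train(epochs=3):
    """Placeholder for model.fit(); reports progress once per epoch."""
    ran_on["train"] = threading.current_thread().name
    for epoch in range(epochs):
        # Only the tiny UI update is handed over to the main thread.
        schedule_on_main(update_label, str(epoch))


# The heavy work runs directly in the worker, not through schedule_on_main.
worker = threading.Thread(target=train, name="worker", daemon=True)
worker.start()
# join() is here only to make the demo deterministic; a real GUI would keep
# running its event loop instead of waiting for the worker.
worker.join()

# The main thread drains its queue, as an event loop would on each frame.
while not main_queue.empty():
    func, args = main_queue.get()
    func(*args)

print(ran_on)
```

Had `train` itself been passed to `schedule_on_main`, it would have run on the main thread and stalled the loop, which is the analogue of scheduling `start_test` with `Clock.schedule_once` in the question.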