培训期间将芹菜与芹菜配对悬挂

时间:2018-09-12 11:59:51

标签: django tensorflow keras celery

我有使用Keras训练神经网络的代码。

首先,我通过类似于网格搜索的算法运行它以获取最佳参数,然后使用最佳参数进行实际预测。

代码本身完全可以正常工作;直到我用celery(Django后端)运行它之后,我才开始出现标题中所述的问题。

为澄清起见,训练在网格搜索过程中完全可以正常工作,但是当训练完成并再次以最佳参数运行时,它只是挂在“ Epoch 1/1”上。

在研究中,我读到我需要将keras导入限制在一个区域,因为它不适用于多处理。但是,我确定是这种情况,甚至尝试将keras导入放入训练它的函数中,但仍然存在相同的问题。

编辑:我尝试了以下操作,但仍然遇到相同的问题...

@worker_process_init.connect()
def init_worker_process(**kwargs):

    import tensorflow as tf
    session_conf = tf.ConfigProto(
       intra_op_parallelism_threads=1,
       inter_op_parallelism_threads=1
    )

    import keras
    from keras import backend as K
    K.clear_session()
    tf.reset_default_graph()
    tf.set_random_seed(7)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)

    from keras.models import Sequential
    from keras.layers import Dense, Dropout, LSTM, GRU
    from keras.callbacks import EarlyStopping
    from keras.optimizers import Adam, Nadam, Adamax, RMSprop
    from keras import regularizers
    from keras.initializers import glorot_uniform, Orthogonal

@shared_task
def myfunc(inputs...):
    #...call function that performs everything, including training with keras

我的问题是:如何使Keras与芹菜一起正常工作?

1 个答案:

答案 0 :(得分:0)

GRAPH = tf.get_default_graph()

@shared_task
def myfunc(inputs...):
    with GRAPH.as_default():
       model_loaded = model.create_model() // define archs, load weight etc
       model_loaded.fit(data,labels ....)

tensorflow是特定于线程的数据结构,因此您不能在芹菜中使用tf.load_model,tf.fit等作为进程驱动的拱门。

为此,您需要使用eventletgevent库。

  

pip安装gevent

以芹菜身份运行:

  

芹菜-我的网站工作人员-pool gevent -l信息