我有使用Keras训练神经网络的代码。
首先,我通过类似于网格搜索的算法运行它以获取最佳参数,然后使用最佳参数进行实际预测。
代码本身完全可以正常工作;直到我用celery(Django后端)运行它之后,我才开始出现标题中所述的问题。
为澄清起见,训练在网格搜索过程中完全可以正常工作,但是当训练完成并再次以最佳参数运行时,它只是挂在“ Epoch 1/1”上。
在研究中,我读到我需要将keras导入限制在一个区域,因为它不适用于多处理。但是,我确定是这种情况,甚至尝试将keras导入放入训练它的函数中,但仍然存在相同的问题。
编辑:我尝试了以下操作,但仍然遇到相同的问题...
@worker_process_init.connect()
def init_worker_process(**kwargs):
import tensorflow as tf
session_conf = tf.ConfigProto(
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1
)
import keras
from keras import backend as K
K.clear_session()
tf.reset_default_graph()
tf.set_random_seed(7)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, GRU
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam, Nadam, Adamax, RMSprop
from keras import regularizers
from keras.initializers import glorot_uniform, Orthogonal
@shared_task
def myfunc(inputs...):
#...call function that performs everything, including training with keras
我的问题是:如何使Keras与芹菜一起正常工作?
答案 0 :(得分:0)
GRAPH = tf.get_default_graph()
@shared_task
def myfunc(inputs...):
with GRAPH.as_default():
model_loaded = model.create_model() // define archs, load weight etc
model_loaded.fit(data,labels ....)
tensorflow是特定于线程的数据结构,因此您不能在芹菜中使用tf.load_model,tf.fit等作为进程驱动的拱门。
pip安装gevent
以芹菜身份运行:
芹菜-我的网站工作人员-pool gevent -l信息