我一直在努力用tf.Graphs
和tf.Sessions
管理多个Keras模型,这已经有几个星期了。简而言之,我想打开多个模型并根据需要在它们之间切换。这包括培训新模型,从文件打开并进行预测。
最重要的是:(几乎)一切正常,直到程序崩溃并退出代码0xC0000005
为止。没有给出错误信息。让我解释一下。
您明白了。这就是我目前管理图表和会话的方式。我使用上下文管理器将创建的图形和会话设置为默认值,然后再切换到先前的状态。
class NeuralNetwork:
def __init__(self):
self.graph = tf.Graph()
self.session = tf.Session(graph=self.graph)
self.model = None
def close(self):
self.session.close()
del self.graph
self.graph = None
gc.collect()
@contextmanager
def _context(self):
prev = k.get_session()
k.set_session(self.session)
with self.graph.as_default(), self.session.as_default():
yield
k.set_session(prev)
def predict(self, x):
with self._context():
return self.model.predict(x)
def fit(self, x_train, y_train, n=20, batch=256):
with self._context():
self.model.fit(x_train, y_train, epochs=n, batch_size=batch, verbose=0)
def create(self, shape):
with self._context():
self.model = Sequential()
self.model.add(Dense(shape[1], input_dim=shape[0], activation='relu'))
self.model.add(Dropout(drop))
self.model.add(Dense(shape[2], activation='sigmoid'))
self.model.compile(loss='binary_crossentropy', optimizer='rmsprop')
def load(self, path, sfx=''):
with open(path / ('architecture' + sfx + '.json'), 'r') as f:
js = f.read()
with self._context():
self.model = model_from_json(js)
self.model.load_weights(path / ('weights' + sfx + '.h5'))
self.model.compile(loss='binary_crossentropy', optimizer='rmsprop')
def save(self, path, sfx=''):
path.mkdir(exist_ok=True)
with self._context():
js = self.model.to_json()
with open(path / ('architecture' + sfx + '.json'), 'w') as f:
f.write(js)
self.model.save_weights(path / ('weights' + sfx + '.h5'))
对于上面的类,这是在其他地方使用网络的方式:
def create(self):
x, y = [], []
shape = (15, 30, 1)
self.predictor = NeuralNetwork()
self.predictor.create(shape)
self.predictor.fit(x, y)
self.predictor.save(path=self.path)
self.predictor.close()
def load(self):
self.predictor.load(path=self.path)
def predict(x):
# Executed only on loaded networks, never on created networks
# due to program structure
return self.predictor.predict(x)
这是我以前为阐明问题所做的努力。
尽我所能,并在某些人的帮助下,我试图找到一种管理这些资源的方法(上下文管理器,并在培训后“关闭”网络)。但是我没有看到详细描述Tensorflow或Keras资源管理过程的文档或教程。
我的目标有两个。
如果您能帮助我实现甚至朝着任一个方向迈进一小步,我将不胜感激!我的经验是,我的奋斗既不是独特的,也不是别人没有想到的。因此,我必须缺少正确的方法。
答案 0 :(得分:1)
已通过将所有软件包更新到最新版本解决了该问题。可悲的是,我一口气进行了升级,这意味着我不确定真正的原因是什么。但我愿意打赌Tensorflow。
以下是产生错误最可能涉及的软件包版本及其更新版本:
tensorflow==1.8.0 -> 1.12.0
numpy==1.14.5 -> 1.15.4
scikit-learn==0.19.1 -> 0.20.0