我有一个写为函数的模型,如下所示:
def __baseModel(self, nodes=300, lr=0.001):
_model = Sequential()
_model.add(Dense(nodes, input_dim=self.inputDim, kernel_initializer='he_normal', activation='relu'))
_model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
_optimizer = self.get_optimizer(learn_rate=lr, dcy=float(self.cfg['init_decay']), eps=float(self.cfg['init_epsilon']))
_model.compile(loss='binary_crossentropy', optimizer='Nadam', metrics=['accuracy'])
return _model
此函数位于名为classifier.py
的python文件中。我将在另一个名为demodel.py
的python文件中调用此函数,该文件具有模型上的训练/预测代码。最后,在main.py
中,我想调用一个函数,在其中可以仅打印模型摘要。我的问题是,我应该在哪个文件中编写一个可以打印出模型摘要的函数,以及如何编写这样的函数?
答案 0 :(得分:1)
您可以为此使用面向对象的概念,那么您不必编写自己的model.summary
函数
基本上,对 keras 模型使用 Singleton设计模式/类,然后您可以从任何其他文件/模块/类访问此Singleton。
您在main.py
中仅创建此Singleton类的一个实例,在任何其他文件/模块中,您只能访问在main.py
中创建的单个实例(因此,Singleton) ,这样您就可以从整个程序中访问相同的 keras模型:
modelsingleton.py
class ModelSingleton():
# Here will be the instance stored.
__instance = None
@staticmethod
def getInstance():
""" Static access method. """
if ModelSingleton.__instance == None:
ModelSingleton()
return ModelSingleton.__instance
def __init__(self):
self.model = self.baseModel()
""" Virtually private constructor. """
if ModelSingleton.__instance != None:
raise Exception("This class is a singleton!")
else:
ModelSingleton.__instance = self
def baseModel(self):
nodes=300
lr=0.001
model = Sequential()
model.add(Dense(nodes, input_dim=4, kernel_initializer='he_normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
#optimizer = self.get_optimizer(learn_rate=lr, dcy=float(self.cfg['init_decay']), eps=float(self.cfg['init_epsilon']))
model.compile(loss='binary_crossentropy', optimizer='Nadam', metrics=['accuracy'])
print('model generated')
return model
main.py (在这里您创建了ModelSingleton的单个实例,定义了您的keras模型)
import modelsingleton
import demodel
model_instance = modelsingleton.ModelSingleton()
model_instance.model.summary()
demodel.defineModel()
demodel.py (在这种情况下,您获取ModelSingleton创建main.py
的实例)
import modelsingleton
def defineModel():
model_instance = modelsingleton.ModelSingleton.getInstance()
model_instance.baseModel()
print('summary demodel')
model_instance.model.summary()
上面的代码简单地打印了两个模型摘要,从main.py
开始,另一个从demodel.py
开始,然后从baseModel()
调用demodel.py
(通过{中的defineModel()
{1}})
我改编了https://gist.github.com/pazdera/1098129中的代码,也可以写出Singleton
https://python-3-patterns-idioms-test.readthedocs.io/en/latest/Singleton.html
中有一个替代方法