将model.summary()编写为函数

时间:2018-09-03 06:09:04

标签: python-2.7 keras

我有一个写为函数的模型,如下所示:

def __baseModel(self, nodes=300, lr=0.001):
        _model = Sequential()
        _model.add(Dense(nodes, input_dim=self.inputDim, kernel_initializer='he_normal', activation='relu'))
        _model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
        _optimizer = self.get_optimizer(learn_rate=lr, dcy=float(self.cfg['init_decay']), eps=float(self.cfg['init_epsilon']))
        _model.compile(loss='binary_crossentropy', optimizer='Nadam', metrics=['accuracy'])
        return _model

此函数位于名为classifier.py的python文件中。我将在另一个名为demodel.py的python文件中调用此函数,该文件具有模型上的训练/预测代码。最后,在main.py中,我想调用一个函数,在其中可以仅打印模型摘要。我的问题是,我应该在哪个文件中编写一个可以打印出模型摘要的函数,以及如何编写这样的函数?

1 个答案:

答案 0 :(得分:1)

您可以为此使用面向对象的概念,那么您不必编写自己的model.summary函数

基本上,对 keras 模型使用 Singleton设计模式/类,然后您可以从任何其他文件/模块/类访问此Singleton。

您在main.py中仅创建此Singleton类的一个实例,在任何其他文件/模块中,您只能访问在main.py中创建的单个实例(因此,Singleton) ,这样您就可以从整个程序中访问相同的 keras模型

modelsingleton.py

class ModelSingleton():
# Here will be the instance stored.
__instance = None



@staticmethod
def getInstance():
    """ Static access method. """
    if ModelSingleton.__instance == None:
        ModelSingleton()
    return ModelSingleton.__instance 

def __init__(self):

    self.model = self.baseModel()

    """ Virtually private constructor. """
    if ModelSingleton.__instance != None:
        raise Exception("This class is a singleton!")
    else:
        ModelSingleton.__instance = self


def baseModel(self):
   nodes=300
   lr=0.001
   model = Sequential() 
   model.add(Dense(nodes, input_dim=4, kernel_initializer='he_normal', activation='relu'))
   model.add(Dense(1, kernel_initializer='normal', activation='sigmoid')) 
   #optimizer = self.get_optimizer(learn_rate=lr, dcy=float(self.cfg['init_decay']), eps=float(self.cfg['init_epsilon']))
   model.compile(loss='binary_crossentropy', optimizer='Nadam', metrics=['accuracy'])
   print('model generated')
   return model

main.py (在这里您创建了ModelSingleton的单个实例,定义了您的keras模型)

import modelsingleton
import demodel

model_instance = modelsingleton.ModelSingleton()

model_instance.model.summary()

demodel.defineModel()

demodel.py (在这种情况下,您获取ModelSingleton创建main.py的实例

import modelsingleton


def defineModel():
    model_instance = modelsingleton.ModelSingleton.getInstance()

    model_instance.baseModel()

    print('summary demodel')
    model_instance.model.summary()

上面的代码简单地打印了两个模型摘要,从main.py开始,另一个从demodel.py开始,然后从baseModel()调用demodel.py(通过{中的defineModel() {1}})

我改编了https://gist.github.com/pazdera/1098129中的代码,也可以写出Singleton

https://python-3-patterns-idioms-test.readthedocs.io/en/latest/Singleton.html

中有一个替代方法