如何冻结参数

时间:2017-06-16 13:39:03

标签: python cntk

我需要在训练期间冻结一些图层的参数。我尝试按needs_gradient设置model.L1.b.needs_gradient = False属性,但我得到以下异常:

AttributeError Traceback (most recent call last)
<ipython-input-57-93ef31fae7d8> in <module>()
----> 1 model.L1.b.needs_gradient = False

/home/aj/anaconda3/envs/cntk-py27/lib/python2.7/site-packages/cntk/cntk_py.pyc in <lambda>(self, name, value)
   1263     for _s in [Variable]:
   1264         __swig_setmethods__.update(getattr(_s, '__swig_setmethods__', {}))
-> 1265     __setattr__ = lambda self, name, value: _swig_setattr(self, Parameter, name, value)
   1266     __swig_getmethods__ = {}
   1267     for _s in [Variable]:

/home/aj/anaconda3/envs/cntk-py27/lib/python2.7/site-packages/cntk/cntk_py.pyc in _swig_setattr(self, class_type, name, value)
     72 
     73 def _swig_setattr(self, class_type, name, value):
---> 74     return _swig_setattr_nondynamic(self, class_type, name, value, 0)
     75 
     76 

/home/aj/anaconda3/envs/cntk-py27/lib/python2.7/site-packages/cntk/cntk_py.pyc in _swig_setattr_nondynamic(self, class_type, name, value, static)
     64     if (not static):
     65         if _newclass:
---> 66             object.__setattr__(self, name, value)
     67         else:
     68             self.__dict__[name] = value

AttributeError: can't set attribute

请帮我消除异常或其他冻结参数的方法。 感谢

2 个答案:

答案 0 :(得分:3)

您可以在声明时使用needs_gradient属性将CNTK变量声明为“冻结”。它将与声明常量相同。但是,如果您想让网络训练一段时间然后冻结参数并在其他培训中使用它,您可以使用这种方法实现:

import cntk
trained_model = get_my_previously_trained_model()
frozen_model = trained_model.clone(cntk.CloneMethod.freeze)
output_from_trained_model = frozen_model(input_features)
model = cntk.layers.Dense(output_dim, activation=None)(output_from_trained_model)

现在,只有Dense图层中的参数可以训练。整个网络中的其余参数都被冻结。希望这个例子有帮助。

答案 1 :(得分:3)

避免更新您不想要的参数的另一种方法是创建学习者并传递需要学习的参数。您可以查看使用此功能的GAN example来切换G和D网络之间的培训。