当尝试使用自定义约束类加载保存的模型时,出现问题。
我的类定义如下:
class WeightsOrthogonalityConstraint(Constraint):
    """Penalty term measuring how far a weight matrix is from orthonormal.

    Calling an instance with a weight tensor returns a scalar penalty, so
    the object can be handed to a layer as a regularizer-style callable.
    ``K`` is assumed to be the Keras backend module and ``Constraint`` the
    Keras constraint base class, both imported elsewhere in the file.

    NOTE(review): the class derives from ``Constraint`` although it is used
    as a ``kernel_regularizer``; the base class is kept unchanged here so
    existing callers and saved models keep working.

    Args:
        encoding_dim: Size of the identity matrix the Gram matrix of the
            weights is compared against.
        weightage: Multiplier applied to the penalty when
            ``encoding_dim > 1``.
        axis: When ``1`` the weights are transposed before the penalty is
            computed; any other value leaves them as given.
    """

    def __init__(self, encoding_dim, weightage=1.0, axis=0):
        self.encoding_dim = encoding_dim
        self.weightage = weightage
        self.axis = axis

    def weights_orthogonality(self, w):
        """Return the orthogonality penalty for the weight tensor ``w``."""
        if self.axis == 1:
            w = K.transpose(w)
        if self.encoding_dim > 1:
            # Root of the summed squared entries of (w^T w - I), scaled
            # by the configured weightage.
            m = K.dot(K.transpose(w), w) - K.eye(self.encoding_dim)
            return self.weightage * K.sqrt(K.sum(K.square(m)))
        # NOTE(review): this branch is not scaled by ``weightage`` and is
        # signed (negative when the squared sum is below 1); kept as in the
        # original -- confirm whether that is intended.
        return K.sum(w ** 2) - 1.

    def __call__(self, w):
        """Evaluate the penalty for ``w``; invoked by Keras with the kernel."""
        return self.weights_orthogonality(w)

    def get_config(self):
        """Return the constructor arguments so the object can be serialized.

        The keys match the ``__init__`` parameter names, which lets a
        deserializer rebuild the object by passing this dict as keyword
        arguments to the class.
        """
        return {
            'encoding_dim': self.encoding_dim,
            'weightage': self.weightage,
            'axis': self.axis,
        }
并且,我像这样使用此类:
# Linear dense layer whose kernel gets the custom orthogonality penalty as
# its regularizer and a unit-norm constraint along axis 0.
encoder = Dense(
    encoding_dim,
    activation="linear",
    # input_shape=(input_dim,),
    use_bias=True,
    kernel_regularizer=WeightsOrthogonalityConstraint(
        encoding_dim, weightage=1., axis=0
    ),
    kernel_constraint=UnitNorm(axis=0),
)
然后,我这样调用load_model:
# custom_objects must map the saved name to the class itself, not to an
# instance. During loading Keras rebuilds the object by calling the mapped
# value with the saved config as keyword arguments; with an instance that
# call lands in __call__(self, w) instead of __init__, which raises the
# "missing 1 required positional argument: 'w'" TypeError.
autoencoder = load_model(
    'anomaly-detection_Fully.h5',
    custom_objects={
        'WeightsOrthogonalityConstraint': WeightsOrthogonalityConstraint,
    },
)
但是,发生错误
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\utils\generic_utils.py", line 154, in deserialize_keras_object
return cls(**config['config'])
TypeError: __call__() missing 1 required positional argument: 'w'
我该如何传入参数“w”?