class MyModel(Model):
def __init__(self,num_classes=1):
super(MyModel, self).__init__()
self.conv1=Convolution2D(filters=8,kernel_size=8,padding='same')
self.batch_norm1=BatchNormalization()
self.activation1=Activation('relu')
self.conv2=Convolution2D(filters=16,kernel_size=8,activation='relu',padding='same')
self.batch_norm2=BatchNormalization()
self.activation2=Activation('relu')
self.MaxPooling2D=MaxPooling2D(pool_size =(2, 2))
self.Flatten=Flatten()
self.dense1=Dense(16,activation='relu')
self.dense2=Dense(num_classes,kernel_regularizer=regularizers.l2(0.4))
def call(self,inputs):
x=self.conv1(inputs)
x=self.batch_norm1(x)
x=self.activation1(x)
x=self.conv2(x)
x=self.batch_norm2(x)
x=self.activation2(x)
x=self.MaxPooling2D(x)
x=self.Flatten(x)
x=self.dense1(x)
return self.dense2(x)
def compute_output_shape(self, input_shape):
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape)
model=MyModel()
adam=Adam(learning_rate=1e-4)
model.compile(optimizer=adam,loss="mse")
earlystopper = EarlyStopping(monitor='val_loss', patience=20, verbose=0)
checkpoint =ModelCheckpoint("C:/Users/user/Desktop/research/pic_recognition/cnn2d-model.hdf5",save_best_only=True)
callback_list=[earlystopper,checkpoint]
model.fit(x_train, y_train, epochs=50, batch_size=8,validation_split=0.1,callbacks=callback_list)
但我收到此错误:
文件“”,第46行,在 model.fit(x_train,y_train,epochs = 50,batch_size = 8,validation_split = 0.1,callbacks = callback_list) 适合的文件“ D:\ Anaconda3 \ lib \ site-packages \ keras \ engine \ training.py”,行1239 validate_freq = validation_freq) fit_loop中的文件“ D:\ Anaconda3 \ lib \ site-packages \ keras \ engine \ training_arrays.py”,第216行 callbacks.on_epoch_end(epoch,epoch_logs) 文件“ D:\ Anaconda3 \ lib \ site-packages \ keras \ callbacks \ callbacks.py”,第152行,位于on_epoch_end中 callback.on_epoch_end(epoch,logs) 文件“ D:\ Anaconda3 \ lib \ site-packages \ keras \ callbacks \ callbacks.py”,行719,位于on_epoch_end中 self.model.save(文件路径,覆盖=真) 保存文件“ D:\ Anaconda3 \ lib \ site-packages \ keras \ engine \ network.py”,行1150 引发NotImplementedError NotImplementedError
答案 0 :(得分:0)
对于自定义模型,您必须对ModelCheckpoint()使用“ save_weights_only = True”或使用model.save_weights()
有关更多详细信息,请参见以下链接: