我尝试使用google colab资源来保存CNN模型权重,但出现此错误。我尝试使用Google搜索,但是没有帮助。
“顺序”对象没有属性“ _in_multi_worker_mode”
我的代码:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1)
cnn_model = Sequential()
cnn_model.add(Conv2D(filters = 64, kernel_size = (3,3), activation = "relu", input_shape = Input_shape ))
cnn_model.add(Conv2D(filters = 64, kernel_size = (3,3), activation = "relu"))
cnn_model.add(MaxPooling2D(2,2))
cnn_model.add(Dropout(0.4))
cnn_model = Sequential()
cnn_model.add(Conv2D(filters = 128, kernel_size = (3,3), activation = "relu"))
cnn_model.add(Conv2D(filters = 128, kernel_size = (3,3), activation = "relu"))
cnn_model.add(MaxPooling2D(2,2))
cnn_model.add(Dropout(0.3))
cnn_model.add(Flatten())
cnn_model.add(Dense(units = 512, activation = "relu"))
cnn_model.add(Dense(units = 512, activation = "relu"))
cnn_model.add(Dense(units = 10, activation = "softmax"))
history = cnn_model.fit(X_train, y_train, batch_size = 32,epochs = 1,
shuffle = True, callbacks = [cp_callback])
堆栈跟踪:
AttributeError Traceback (most recent call last)
<ipython-input-19-35c1db9636b7> in <module>()
----> 1 history = cnn_model.fit(X_train, y_train, batch_size = 32,epochs = 1, shuffle = True, callbacks = [cp_callback])
4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_train_begin(self, logs)
903 def on_train_begin(self, logs=None):
904 # pylint: disable=protected-access
--> 905 if self.model._in_multi_worker_mode():
906 # MultiWorkerTrainingState is used to manage the training state needed
907 # for preemption-recovery of a worker in multi-worker training.
AttributeError: 'Sequential' object has no attribute '_in_multi_worker_mode'
答案 0 :(得分:3)
检查您的张量流版本。您实际上只需要同步它。检查您所有的导入是否使用
from keras import ...
或
from tensorflow.keras import ...
仅将上述方法之一用于您的keras导入。同时使用两个(两个)可能会导致库冲突。
答案 1 :(得分:2)
代替
tf.keras.callbacks.ModelCheckpoint
在您的模型构建过程中,您可以使用
from keras.callbacks import ModelCheckpoint
为了导入ModelCheckpoint
,然后在后面的代码中使用ModelCheckpoint
。
答案 2 :(得分:0)
我最近也遇到了同样的问题
代替
from tensorflow.keras.callbacks import ModelCheckpoint
使用
from keras.callbacks import ModelCheckpoint
答案 3 :(得分:0)
请检查您的tensorflow版本是否与最新版本相匹配。在我看来,该错误在更新至2.1.0时已解决。