我想在tf.keras中使用回调保存模型,但是出现以下错误,
Epoch 1/500
99/100 [============================>.] - ETA: 0s - loss: 1.4361 - acc: 0.5970
Epoch 00001: val_loss improved from inf to 1.19811, saving model to ./params/test.hdf5
Traceback (most recent call last):
File "main.py", line 61, in <module>
train(parser)
File "main.py", line 52, in train
callbacks=[callback]
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 2177, in fit_generator
initial_epoch=initial_epoch)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py", line 216, in fit_generator
callbacks.on_epoch_end(epoch, epoch_logs)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 214, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 590, in on_epoch_end
self.model.save(filepath, overwrite=True)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1363, in save
save_model(self, filepath, overwrite, include_optimizer)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 103, in save_model
default=serialization.get_json_type).encode('utf8')
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/json/__init__.py", line 238, in dumps
**kw).encode(obj)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/json/encoder.py", line 199, in encode
chunks = self.iterencode(o, _one_shot=True)
File "/home/yudai/.conda/envs/tf-gpu/lib/python3.6/json/encoder.py", line 257, in iterencode
return _iterencode(o, 0)
ValueError: Circular reference detected
如您所见,训练进行得很好,但是错误发生在回调中。
我用以下代码定义ML模型,
import tensorflow as tf
class subbranch:
def __call__(self, inputs):
return self.sub_block(inputs)
def sub_block(self, x):
x = tf.keras.layers.Conv2D(
filters=1024,
kernel_size=3,
padding="SAME",
activation=tf.keras.activations.relu
)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(
filters=1024,
kernel_size=3,
padding="SAME",
activation=tf.keras.activations.relu
)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(
filters=10,
kernel_size=1,
padding="SAME",
activation=tf.keras.activations.relu
)(x)
x = tf.keras.layers.BatchNormalization()(x)
tmp = x
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Softmax()(x)
return x, tmp
class Adversarial(tf.keras.layers.Layer):
def __init__(self):
super(Adversarial, self).__init__(self)
def call(self, inputs, **kwargs):
vgg_end, interm, branchA_end = inputs
max_idx = tf.argmax(branchA_end, axis=1)
tmp = []
#batch_size = kwargs["batch_size"]
batch_size = 10
for bt in range(batch_size):
a = tf.reshape(interm[bt, :, :, max_idx[bt]], [7, 7, 1])
each = tf.tile(a, [1, 1, 512])
#each = tf.stack([interm[bt, :, :, max_idx[bt]] for i in range(512)], axis=-1)
tmp.append(each)
tmp = tf.stack(tmp)
tmp = tf.where(tmp > 0.9, tmp, tmp*0)
adv = tf.subtract(vgg_end, tmp)
return vgg_end
class sum_up(tf.keras.layers.Layer):
def __init__(self):
super(sum_up, self).__init__(self)
def call(self, inputs, **kwargs):
x, y = inputs
return tf.add(x, y)
def ACoL(args):
n_classes = args.n_classes
batch_size = args.batch_size
vgg16 = tf.keras.applications.VGG16(
include_top=False,
input_shape=(224, 224, 3),
classes=n_classes
)
for layer in vgg16.layers:
layer.trainable = False
x = vgg16.output
#branch-A
x, tmp = subbranch()(x)
#branch-B
y = Adversarial()([vgg16.output, tmp, x])
y, _ = subbranch()(y)
output = sum_up()([x, y])
return tf.keras.Model(inputs=vgg16.input, outputs=output)
而且,当我训练该模型时,我使用了以下代码
callback = tf.keras.callbacks.ModelCheckpoint(
filepath="./params/test.hdf5",
monitor="val_loss",
verbose=1,
save_best_only=True,
mode="auto"
)
model = ACoL(args)
model.compile(
optimizer=tf.train.AdamOptimizer(),
loss="categorical_crossentropy",
metrics=["accuracy"]
)
model.fit_generator(
generator,
steps_per_epoch=100,
epochs=epoch,
validation_data=val_generator,
validation_steps=10,
callbacks=[callback]
)
没有callbacks=[callback]
行,没有错误发生。错误消息显示“循环引用”存在,但我找不到对应的位置。