TypeError :(&#39; Not JSON Serializable:&#39;,<tf.variable'variable_1:0'=“”shape =“()”dtype =“float32_ref”>)

时间:2018-03-14 08:12:22

标签: python tensorflow keras

我正在尝试使用自适应# callback for adaptive loss_weight class LossWeightCallback(Callback): def __init__(self, alpha): self.alpha = alpha def on_epoch_end(self, epoch, logs={}): self.alpha = self.alpha * 0.9 # initial loss_weight alpha = K.variable(10) # model img_input = Input(shape=(224, 224, 3), name='input') ... model = Model(inputs=img_input, outputs=[y1, y2]) # compile model.compile(keras.optimizers.SGD(lr=1e-4, momentum=0.9), loss={'output1': 'categorical_crossentropy', 'output2': 'mse'}, loss_weights={'output1': 1, 'output2': alpha}, metrics={'output1': 'accuracy', 'output2': 'mse'}) # Fit model checkpointer = ModelCheckpoint('multitask_model.h5', monitor='val_output1_acc', verbose=1, save_best_only=True) results = model.fit(x_train, {'output1': y_train1, 'output2': y_train2}, validation_split=0.1, batch_size= 100, epochs=50, callbacks=[checkpointer, LossWeightCallback(alpha)]) 来实现多任务CNN,随着时代的增加而衰减。我提到了Github issue

TypeError: ('Not JSON Serializable:', <tf.Variable 'Variable_1:0' shape=() dtype=float32_ref>)

但是此代码在第1个纪元结束后返回错误:

Dim myPicture As Picture 'embedded pic

这个错误有什么解决方法吗? 提前谢谢。

2 个答案:

答案 0 :(得分:0)

这不是一个完美的答案,但是当我从SELECT ExpectationId, ExpectationName, ParentName FROM ExpectationsView WHERE FREETEXT(ExpectationName, @Keyword) OR FREETEXT(ParentName, @Keyword) checkpointer中移除callbacks时,错误就会消失,而且代码运行良好。

model.fit()

使用results = model.fit(x_train, {'output1': y_train1, 'output2': y_train2}, validation_split=0.1, batch_size= 100, epochs=50, callbacks=[LossWeightCallback(alpha)]) 并使用自定义回调时on_epoch_end()函数似乎发生冲突......

答案 1 :(得分:0)

最近,我也遇到了这个问题,幸运的是找到了一个简单的解决方案。在类型错误中,tf.variable不能进行json序列化,因此请尝试使其不为tf.variable类型。请使用以下代码替换原始的编译代码:

model.compile(keras.optimizers.SGD(lr=1e-4, momentum=0.9), 
              loss={'output1': 'categorical_crossentropy', 'output2': 'mse'},
              loss_weights={'output1': 1, 'output2': alpha.numpy()},
              metrics={'output1': 'accuracy', 'output2': 'mse'})