Can I use a dictionary in a custom Keras model?

Time: 2019-08-16 01:38:22

Tags: keras tensorflow2.0

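I recently read a paper on UNet++ and wanted to implement the architecture as a custom model with TensorFlow 2.0 and Keras. Because the structure is quite complex, I decided to manage the Keras layers through a dictionary. Training works fine, but an error occurs when saving the model. Here is a minimal example that reproduces the error: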

import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight

class DicModel(tf.keras.Model):
    def __init__(self):
        super(DicModel, self).__init__(name='SequenceEECNN')
        # Layers are stored in a dict with integer keys (0, 1)
        self.c = {}
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization()]
        )
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c[0](images)
        x = self.c[1](x)
        return x

# load_data() is the asker's own data loader (not shown)
X_train, y_train = load_data()
X_test, y_test = load_data()

class_weight.compute_class_weight('balanced', np.ravel(np.unique(y_train)), np.ravel(y_train))

model = DicModel()
model_name = 'test'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs/'+model_name+'/')
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=100,mode='min')

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])

results = model.fit(X_train,y_train,batch_size=4,epochs=5,validation_data=(X_test,y_test),
                    callbacks=[tensorboard_callback,early_stop_callback],
                    class_weight=[0.2,2.0,100.0])

model.save_weights('model/'+model_name,save_format='tf')

The error message is:

Traceback (most recent call last):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/learn_tf2/test_model.py", line 61, in <module>
    model.save_weights('model/'+model_name,save_format='tf')
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1328, in save_weights
    self._trackable_saver.save(filepath, session=session)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1106, in save
    file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1046, in _save_cached_when_graph_building
    object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1014, in _gather_saveables
    feed_additions) = self._graph_view.serialize_object_graph()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 379, in serialize_object_graph
    trackable_objects, path_to_root = self._breadth_first_traversal()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 199, in _breadth_first_traversal
    for name, dependency in self.list_dependencies(current_trackable):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 159, in list_dependencies
    return obj._checkpoint_dependencies
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 690, in __getattribute__
    return object.__getattribute__(self, name)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 732, in _checkpoint_dependencies
    "ignored." % (self,))
ValueError: Unable to save the object {0: <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fb5c6c36588>, 1: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb5c6c36630>} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.

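tf.contrib.checkpoint.NoDependency appears to have been removed in TensorFlow 2.0, along with the rest of tf.contrib (https://medium.com/tensorflow/whats-coming-in-tensorflow-2-0-d3663832e9b8). How can I fix this? Or should I give up on using dictionaries in custom Keras models? Thanks for your time and help!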

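2 answers:

Answer 0 (score: 1)

Use string keys. As the error message says, the dictionary wrapper that TensorFlow creates on attribute assignment cannot be saved when a non-string key maps to a trackable object such as a layer, so integer keys like 0 and 1 fail.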
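A minimal sketch of Answer 0, assuming TensorFlow 2.x with the built-in tf.keras: it is the question's DicModel with the keys changed to the strings '0' and '1'. Instead of relying on the save_weights signature (which differs between Keras versions), it saves with tf.train.Checkpoint, which, like save_weights(..., save_format='tf') in the traceback, serializes the object graph through TensorFlow's trackable saver. The input shape (1, 8, 8, 1) and the temporary directory are arbitrary choices for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class DicModel(tf.keras.Model):
    """The question's model, but the layer dict uses string keys."""

    def __init__(self):
        super().__init__(name='SequenceEECNN')
        self.c = {}
        # String keys can serve as checkpoint dependency names;
        # integer keys trigger the ValueError shown in the question.
        self.c['0'] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c['1'] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c['0'](images)
        return self.c['1'](x)


model = DicModel()
# Call the model once on a dummy batch so that all variables are created.
out = model(np.zeros((1, 8, 8, 1), dtype=np.float32))

# Object-based checkpoint; this walks the same object graph that failed above.
save_path = tf.train.Checkpoint(model=model).save(
    os.path.join(tempfile.mkdtemp(), 'test'))
print(save_path)
```

For a UNet++-style grid of nodes, the same idea carries over to composite string keys such as f'x_{i}_{j}' (a hypothetical naming scheme) in place of integer or tuple keys.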

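Answer 1 (score: 0)

The exception message in TensorFlow 2.0 is incorrect (it points to tf.contrib.checkpoint.NoDependency, but tf.contrib no longer exists in 2.0); this was fixed in 2.2.

You can avoid the problem by wrapping the c attribute: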

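# Private TensorFlow module; this import path may change between versions
from tensorflow.python.training.tracking.data_structures import NoDependency
self.c = NoDependency({})  # inside __init__, replacing self.c = {}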

For more details, see this issue.
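The error message in the question states that a dictionary wrapped in NoDependency is ignored by checkpointing, so with that workaround it can be worth checking which weights were actually written. One way is tf.train.list_variables, which lists the keys stored in a checkpoint. The sketch below is an illustration rather than part of either answer: TinyDictModel and checkpointed_names are hypothetical names, and it again assumes TensorFlow 2.x with tf.keras.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyDictModel(tf.keras.Model):
    """Hypothetical one-layer model whose layer lives in a string-keyed dict."""

    def __init__(self):
        super().__init__()
        self.blocks = {'head': tf.keras.layers.Dense(2)}

    def call(self, x):
        return self.blocks['head'](x)


def checkpointed_names(save_path):
    """Return the variable keys stored in an object-based checkpoint."""
    return [name for name, _shape in tf.train.list_variables(save_path)]


model = TinyDictModel()
model(np.zeros((1, 4), dtype=np.float32))  # creates the Dense kernel and bias
save_path = tf.train.Checkpoint(model=model).save(
    os.path.join(tempfile.mkdtemp(), 'tiny'))

# Each printed key shows the object path under which a variable was saved.
for name in checkpointed_names(save_path):
    print(name)
```

Running the same check on a model whose layers are reachable only through a NoDependency-wrapped dictionary shows whether their kernels and biases appear in the checkpoint at all.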