如何使用自定义图层Tensorflow2保存自定义模型

时间:2020-07-15 17:58:33

标签: python tensorflow deep-learning neural-network model

我正在尝试减轻模型的重量:model.save_weights('filepath')

我的模特是

class CLayer(Layer):
    def __init__(self, hidden_units=[]):
        super(PartCoder, self).__init__()
        # self.layers = NoDependency([])
        # self.__dict__['layers'] = []
        self.layers = []
        for u in hidden_units:
            self.layers.append(keras.layers.Dense(u))

class CModel(Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.cLayers = CLayer(hidden_units=hidden_units)


model_1 = CModel(hidden_units=[2,4])  
model_1.save_weights("filepath_1")  #--> worked

model_2 = CModel(hidden_units=[2,4])
model_2.load_weights("filepaht_1")  #--> worked

#train model_2
model_2.save_weights("filepath_2") #--> Crash here

结果:

Epoch 5: Loss = 72.76963806152344
Epoch 10: Loss = 68.98507690429688
Epoch 15: Loss = 65.25872039794922
Epoch 20: Loss = 61.60311508178711
Epoch 25: Loss = 58.087005615234375

但是有错误:

ValueError: Unable to save the object ListWrapper([<tensorflow.python.keras.layers.core.Dense object at 0x7f032829d4e0>, <tensorflow.python.keras.layers.core.Dense object at 0x7f032829da90>, <tensorflow.python.keras.layers.core.Dense object at 0x7f032ad38e80>, <tensorflow.python.keras.layers.core.Dense object at 0x7f0328165048>]) (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced (__setitem__, __setslice__), deleted (__delitem__, __delslice__), or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures.

If you don't need this list checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.

在Google上找到之后,我尝试使用:

self.__dict__['layers'] = []
#OR: self.layers = NoDependency([])

用于保存CLayer类中的图层列表。 该解决方案是可行的,但图层中的权重不会更新。

Epoch 5: Loss = 76.57701873779297
Epoch 10: Loss = 76.57701873779297
Epoch 15: Loss = 76.57701873779297
Epoch 20: Loss = 76.57701873779297

链接:Tried solution

我的问题是如何将带有自定义图层的模型保存为上述示例模型而没有错误。

1 个答案:

答案 0 :(得分:0)

经过很长时间,我会尽一切可能进行搜索和修复。然后,我搬到了Pytorch。保存模型没有错误更容易。 最后,我通过使用Pytorch解决了问题,并且有效。