我有一个tf.Module
类,其中包含一个(非可剥皮的)tf.keras.Model
作为子模块。我想知道在这种情况下推荐的tf.Module
序列化方式是什么?
我考虑了两种方法:
tf.keras.Model.save
类似的内容。我希望也许tf.Module
能够像tf.Model.save
一样保存嵌套模块。 tf.Module
尚未实现。tf.Module
的一种简单方法,但是我不能这样做,因为tf.keras.Model
是不可酸洗的。这是当前失败的示例代码:
import pickle
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
def main():
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
# Note, model *is not* picklable.
model = tf.keras.Model(x, y)
_ = model(tf.random.uniform((1, 3)))
module_1 = TestModule(model)
module_2 = pickle.loads(pickle.dumps(module_1))
for variable_1, variable_2 in zip(module_1.model.trainable_variables,
module_2.model.trainable_variables):
tf.debugging.assert_equal(variable_1, variable_2)
if __name__ == '__main__':
main()
我应该为每个__{get,set}state__
编写自定义的pickle功能(例如tf.Module
)还是创建与.save
相似的keras.Model
方法?
答案 0 :(得分:0)
您可以使用Saved Model Format保存自定义的tf.Module
子类。
以下适用于Tensorflow 2.1:
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)
module_1 = TestModule(model)
tf.saved_model.save(module_1, "./foo")
要重新加载:
imported = tf.saved_model.load("foo")
断言
module_1 == imported
(或类似名称)将引发AssertionError
,因为加载后,我们正在处理其他Tensorflow对象。但是,我们可以遍历模型的权重并逐个比较它们的权重:
original_weights = module_1.model.weights
imported_weights = imported.model.variables.weights
for weight_idx, _ in enumerate(original_weights):
assert (
original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
).all()