推荐序列化tf.Module的推荐方法是什么?

时间:2020-01-17 17:00:02

标签: python tensorflow keras tensorflow2.0

我有一个tf.Module类,其中包含一个(非可剥皮的)tf.keras.Model作为子模块。我想知道在这种情况下推荐的tf.Module序列化方式是什么?

我考虑了两种方法:

  1. 使用与tf.keras.Model.save类似的内容。我希望也许tf.Module能够像tf.Model.save一样保存嵌套模块。 tf.Module尚未实现。
  2. 酸洗,这是序列化tf.Module的一种简单方法,但是我不能这样做,因为tf.keras.Model是不可酸洗的。

这是当前失败的示例代码:

import pickle

import tensorflow as tf


class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


def main():
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # Note, model *is not* picklable.
    model = tf.keras.Model(x, y)

    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(model)
    module_2 = pickle.loads(pickle.dumps(module_1))

    for variable_1, variable_2 in zip(module_1.model.trainable_variables,
                                      module_2.model.trainable_variables):
        tf.debugging.assert_equal(variable_1, variable_2)


if __name__ == '__main__':
    main()

我应该为每个__{get,set}state__编写自定义的pickle功能(例如tf.Module)还是创建与.save相似的keras.Model方法?

1 个答案:

答案 0 :(得分:0)

您可以使用Saved Model Format保存自定义的tf.Module子类。

以下适用于Tensorflow 2.1:

import tensorflow as tf

class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)
module_1 = TestModule(model)

tf.saved_model.save(module_1, "./foo")

要重新加载:
imported = tf.saved_model.load("foo")

断言
module_1 == imported(或类似名称)将引发AssertionError,因为加载后,我们正在处理其他Tensorflow对象。但是,我们可以遍历模型的权重并逐个比较它们的权重:

original_weights = module_1.model.weights
imported_weights = imported.model.variables.weights

for weight_idx, _ in enumerate(original_weights):
  assert (
      original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
      ).all()