How to save a model in TensorFlow Federated

Date: 2019-11-10 04:52:32

Tags: tensorflow-federated

How can I save the model in the code below?

If you want to run the code, visit https://github.com/tensorflow/federated and download federated_learning_for_image_classification.ipynb.

I would appreciate it if you could tell me how to save the model trained by federated learning in the federated_learning_for_image_classification.ipynb tutorial.

3 Answers:

Answer 0 (score: 0):

Roughly speaking, we will be using the object here and its save_checkpoint / load_checkpoint methods. In particular, you can instantiate a FileCheckpointManager and ask it to save state (almost) directly.

In your example, state is an instance of tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC), which is not compatible with tf.convert_to_tensor, as save_checkpoint requires and declares in its docstring. The general solution often used in TFF research code is to introduce a Python attrs class to convert the state away from the anonymous tuple as soon as it is returned; a hypothetical sketch of such a class follows.
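For illustration only, here is what that pattern might look like, assuming a server state with model and optimizer_state fields (these field names are made up and must mirror the structure of your actual state):

import attr

@attr.s(frozen=True)
class ServerState(object):
  # Illustrative fields; list whatever your actual state contains.
  model = attr.ib()
  optimizer_state = attr.ib()

  @classmethod
  def from_anon_tuple(cls, anon_tuple):
    # Pull the named elements out of the AnonymousTuple, yielding an
    # ordinary Python object whose leaves tf.convert_to_tensor can handle.
    return cls(model=anon_tuple.model,
               optimizer_state=anon_tuple.optimizer_state)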

With a class like that in hand, the following sketch should work:

# state: the anonymous tuple previously returned by the iterative process
# N: the round number (some integer)

ckpt_manager = FileCheckpointManager(...)  # e.g. a root directory to write checkpoints to
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

To restore from this checkpoint, you can call at any time (the freshly initialized state serves as a structure template for deserialization):

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))

One thing to be aware of: the code pointers linked above generally live under tff.python.research..., which is not included in the pip package; so the preferred ways to get them are to fork the code into your own project, or to pull down the repository and build from source.

Thanks for your interest in TFF!

Answer 1 (score: 0):

Would model.save_weights work for this problem? I know FileCheckpointManager does something more complete (capturing the weights each round), but I imagine that, as far as the final federated-averaged model is concerned, the parameter space should be reachable through save_weights; see the sketch below.
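A minimal sketch of that idea, assuming create_keras_model() is the tutorial's model constructor and that state.model exposes assign_weights_to (the exact API has moved between TFF releases):

keras_model = create_keras_model()          # same architecture used inside model_fn
state.model.assign_weights_to(keras_model)  # copy the trained federated weights
keras_model.save_weights('/tmp/federated_model_weights')

# Later, plain Keras is enough to get the final model back:
restored_model = create_keras_model()
restored_model.load_weights('/tmp/federated_model_weights')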

Answer 2 (score: 0):

You can use the FileCheckpointManager defined in

https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/simulation/checkpoint_manager.py

However, the released version of TFF (v0.18.0) does not include this class, so you should copy that file into your project directory to be able to import FileCheckpointManager:


# Prerequisites from the tutorial; your own code up to this point
# should already have built these:
iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()

from checkpoint_manager import FileCheckpointManager

fcm = FileCheckpointManager('checkpoint/')  # directory where checkpoints are written

# Save model

round_num = 110  # depends on how many rounds you have trained
fcm.save_checkpoint(state, round_num)

# Load model

state, round_num = fcm.load_latest_checkpoint(state)  # the passed state is a structure template
state, metrics = iterative_process.next(state, federated_train_data)  # continue training
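
Building on the last two lines, a hypothetical continuation that resumes training from the restored round (NUM_ROUNDS and the checkpoint interval are illustrative, not part of the original answer):

NUM_ROUNDS = 200  # illustrative target round count

for r in range(round_num + 1, NUM_ROUNDS + 1):
    state, metrics = iterative_process.next(state, federated_train_data)
    if r % 10 == 0:  # checkpoint every 10 rounds (arbitrary interval)
        fcm.save_checkpoint(state, r)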