如何在打击代码中保存模型
如果您想运行代码,请访问https://github.com/tensorflow/federated 并下载federated_learning_for_image_classification.ipynb。
如果您在federated_learning_for_image_classification.ipynb教程中告诉我如何保存联合学习的模型,我将不胜感激。
machine
答案 0 :(得分:0)
大致上,我们将使用对象here及其save_checkpoint
/ load_checkpoint
方法。特别是,您可以实例化一个FileCheckpointManager
,并要求它直接(几乎)保存state
。
state
是tff.python.common_libs.anonymous_tuple.AnonymousTuple
所需要的tf.convert_to_tensor
(IIRC)实例,它与save_checkpoint
不兼容,并在其文档字符串中声明。 TFF研究代码中经常使用的通用解决方案是引入Python attr
的类,以在返回状态后立即将其从匿名元组转换为其他示例。
假设以上所述,以下草图应适用:
# state assumed an anonymous tuple, previously created
# N some integer
ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)
要从此检查点恢复,可以随时致电:
state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
ServerState.from_anon_tuple(state))
要注意的一件事:上面链接的代码指针通常在tff.python.research...
中,不包含在pip包中;因此,获取它们的首选方法是将代码放入您自己的项目中,或者拉下存储库并从源代码进行构建。
感谢您对TFF的关注!
答案 1 :(得分:0)
model.save_weights是否适用于此问题?我知道FileCheckpointManager会做得更完整(捕获每轮权重),但是我想就最终的联邦平均模型而言,参数空间应该在save_weights中可用。
答案 2 :(得分:0)
您可以在
中使用FileCheckpointManager
类
但是,TFF 的发布版本 (v0.18.0) 不支持此类。您应该将此文件复制到您的项目目录中,以便您可以导入 FileCheckpointManager
。
'''
# PASTE YOUR CODE BEFORE HERE
# Required:
iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()
'''
from checkpoint_manager import FileCheckpointManager
fcm = FileCheckpointManager('checkpoint/')
# Save model
round_num = 110 # It depends on rounds you have trained
fcm.save_checkpoint(state, round_num)
# Load model
state, round_num = fcm.load_latest_checkpoint(state)
state, metrics = iterative_process.next(state, federated_train_data)