我正在研究使用tensorflow联合API的federated_learning_for_image_classification.ipynb。
在示例中,我可以检查每个模拟的客户训练的准确性,损失和总准确性,总损失。
但是没有检查点文件。
我要制作每个客户端检查点文件和全部检查点文件。
然后比较客户端参数变量和总参数变量。
有人可以在federated_learning_for_image_classification.ipynb示例中帮助我制作检查点文件吗?
答案 0 :(得分:1)
要问的一个问题是,您是否要比较变量Tem内的Tem(作为联合计算的一部分)还是事后/外部TFF(在Python中进行分析)。
修改tff.learning.build_federated_averaging_process
执行的tff.utils.IterativeProcess
构造可能是一个不错的方法。实际上,我建议在tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py
上分叉GitHub上的简化实现,而不是深入研究tff.learning
。
更改将the line从客户端更新到tff.fedetated_mean
的tff.federated_collect
将会给出所有客户端模型的列表,然后可以将其与全局模型进行比较。
示例:
client_deltas = tff.federated_collect(client_outputs.weights_delta)
@tff.tf_computation(server_state.model.type_signature,
client_deltas.type_signature)
def compare_deltas_to_global(global_model, deltas):
for delta in deltas:
# do something with delta vs global_model
tff.federated_apply(compare_deltas_to_global, (server_state.model, client_deltas))