tensorflow联合学习检查点

时间:2019-10-05 11:45:30

标签: tensorflow tensorflow-federated

我正在研究使用tensorflow联合API的federated_learning_for_image_classification.ipynb。

在示例中,我可以检查每个模拟的客户训练的准确性,损失和总准确性,总损失。

但是没有检查点文件。

我要制作每个客户端检查点文件和全部检查点文件。

然后比较客户端参数变量和总参数变量。

有人可以在federated_learning_for_image_classification.ipynb示例中帮助我制作检查点文件吗?

1 个答案:

答案 0 :(得分:1)

要问的一个问题是,您是否要比较变量Tem内的Tem(作为联合计算的一部分)还是事后/外部TFF(在Python中进行分析)。

修改tff.learning.build_federated_averaging_process执行的tff.utils.IterativeProcess构造可能是一个不错的方法。实际上,我建议在tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py上分叉GitHub上的简化实现,而不是深入研究tff.learning

更改将the line从客户端更新到tff.fedetated_meantff.federated_collect将会给出所有客户端模型的列表,然后可以将其与全局模型进行比较。

示例:

client_deltas = tff.federated_collect(client_outputs.weights_delta)

@tff.tf_computation(server_state.model.type_signature,
                    client_deltas.type_signature)
def compare_deltas_to_global(global_model, deltas):
  for delta in deltas:
    # do something with delta vs global_model 

tff.federated_apply(compare_deltas_to_global, (server_state.model, client_deltas))