如何在联邦的Tensorflow中打印本地输出?

时间:2019-05-28 02:13:44

标签: tensorflow-federated

我想在tensorflow联合教程https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification中打印客户端的本地输出。我该怎么办?

1 个答案:

答案 0 :(得分:0)

如果您只想要汇总中的值的列表(例如,放入tff.federated_mean中),一种选择是向aggregate_mnist_metrics_across_clients()添加其他输出,以包括使用{{3}计算的指标}。

这可能看起来像:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return {
      'num_examples': tff.federated_sum(metrics.num_examples),
      'loss': tff.federated_mean(metrics.loss, metrics.num_examples),
      'accuracy': tff.federated_mean(metrics.accuracy, metrics.num_examples),
      'per_client/num_examples': tff.federated_collect(metrics.num_examples),
      'per_client/loss': tff.federated_collect(metrics.loss),
      'per_client/accuracy': tff.federated_collect(metrics.accuracy),
  }

稍后在运行计算时将打印出几个单元格:

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))

round  1, metrics=<...,per_client/accuracy=[0.14516129, 0.10642202, 0.13972603],per_client/loss=[3.2409852, 3.417463, 2.9516447],per_client/num_examples=[930.0, 1090.0, 730.0]>

请注意:如果您想了解特定客户的价值,则无意这样做。通过设计,TFF的语言故意避免了客户身份的概念。希望避免使客户成为可寻址对象。