如何控制联合框架的验证数据

时间:2019-06-25 15:13:43

标签: tensorflow-federated

我正在尝试指定通过联合框架传递给每个客户端以进行培训/验证的验证数据。我知道tensorflow-federated会从每个客户端的数据集中随机抽取样本并对其进行验证。但是,如果我的数据(在一个子集中)非常相关,如何在TFF框架中为每个客户端指定验证数据集? 您认为对数据进行混排在这里有意义吗? (例如,使用: DS.repeat(FL_rpt).shuffle(FL_shuf).batch(FL_batch)) 如果是这样,对shuffle_buffer的大小有什么建议吗?

在keras训练中,我们有以下内容可以在Set A上训练模型并在Set B上验证训练:

model.fit(InA,OutA, validation_data=(In_valid_B,Out_valid_B),batch_size=100,epochs=100)

我们如何对联合框架执行相同的操作?

1 个答案:

答案 0 :(得分:0)

这可能是在模拟过程中在外部Python循环中编写的。当前的API在单个回合中都没有评估和培训的概念。

如果使用TFF中包含的模拟数据集(例如tff.simulation.datasets下的模拟数据集),则它们将包含训练/测试拆分,从而使此操作变得容易。每个返回一个2元组的tff.simulation.ClientData对象,一个测试和一个训练ClientData。测试和训练都具有相同的ClientData.client_id列表,但是tf.data.Dataset返回的create_tf_dataset_for_client(client_id)将具有不相交的示例集。

换句话说,训练和测试拆分是针对用户示例的,而不是针对用户示例的。

联合培训和联合评估循环可能类似于:

train_data, test_data = tff.simulation.datasets.shakespeare.load_data()

federated_average = tff.learning.build_federated_averaging_process(model_fn, ...)
federated_eval = tff.learning.build_federated_evaluation(model_fn)

state = federated_average.initialize()

for _ in range(NUM_ROUNDS):
  participating_clients = numpy.random.choice(train_data.client_ids, size=5)

  # Run a training pass
  clients_train_datasets = [
    train_data.create_tf_dataset_for_client(c) for c in participating_clients
  ]
  state, train_metrics = federated_average.next(state, client_train_datasets)

  # Run an evaluation pass
  client_eval_datasets = [
    test_data.create_tf_dataset_for_client(c) for c in participating_clients
  ]
  eval_metrics = federated_eval(state.model, client_eval_datasets)