How to save a trained TensorFlow Federated model as a .h5 model?

Asked: 2020-03-26 21:23:17

Tags: tensorflow-federated

I would like to save a TensorFlow Federated model trained with the FedAvg algorithm as a Keras / .h5 model. I could not find any documentation on this and would like to know how it can be done. Additionally, if possible, I would like to access both the aggregated server model and the client models.

The code I use to train the federated model is as follows:

import time

import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(segment_size,num_input_channels)),
      tf.keras.layers.Flatten(), 
      tf.keras.layers.Dense(units=400, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(units=100, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(activityCount, activation='softmax'),
    ])
    return tff.learning.from_keras_model(
      model,
      dummy_batch=batch,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))

def evaluate(num_rounds=communicationRound):
  state = trainer.initialize()
  roundMetrics = []
  evaluation = tff.learning.build_federated_evaluation(model_fn)

  for round_num in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    test_metrics = evaluation(state.model, train_data)

    roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    roundMetrics.append('round time={}'.format(t2 - t1))
    print('round {:2d}, accuracy={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    print('round time={}'.format(t2 - t1))
  outF = open(filepath+'stats'+architectureType+'.txt', "w")
  for line in roundMetrics:
    outF.write(line)
    outF.write("\n")
  outF.close()

1 Answer:

Answer 0 (score: 3):

Roughly speaking, we will be using the save_checkpoint / load_checkpoint methods. In particular, you can instantiate a FileCheckpointManager and ask it to save the state (almost) directly.

The state in your example is an instance of tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC), which is not compatible with tf.convert_to_tensor, as save_checkpoint requires and declares in its docstring. The general solution often used in TFF research code is to introduce a Python attrs class to convert away from the anonymous tuple as soon as the state is returned.
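
For concreteness, here is a minimal sketch of such an attrs class, assuming the FedAvg server state exposes model, optimizer_state, delta_aggregate_state and model_broadcast_state fields (only model appears in the question's code, so verify the remaining names against your own state object):

import attr

@attr.s(frozen=True)
class ServerState(object):
  # Field names mirror the assumed FedAvg server state; adjust them if your
  # TFF version exposes different fields.
  model = attr.ib()
  optimizer_state = attr.ib()
  delta_aggregate_state = attr.ib()
  model_broadcast_state = attr.ib()

  @classmethod
  def from_anon_tuple(cls, anon_tuple):
    # AnonymousTuple supports attribute access, so the named elements can be
    # pulled out and re-wrapped in a plain attrs object that save_checkpoint
    # can handle.
    return cls(
        model=anon_tuple.model,
        optimizer_state=anon_tuple.optimizer_state,
        delta_aggregate_state=anon_tuple.delta_aggregate_state,
        model_broadcast_state=anon_tuple.model_broadcast_state)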

Assuming the above, a sketch along the following lines should work:

# state assumed an anonymous tuple, previously created
# N some integer 

ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

To restore from this checkpoint, you can then call:

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))
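
To get from a (restored) server state back to the Keras / .h5 model the question asks for, one pattern from that era of TFF is to rebuild the same Keras architecture and copy the aggregated server weights into it. The helper used below, tff.learning.assign_weights_to_keras_model, is an assumption about the TFF 0.x API in use (later releases moved this to state.model.assign_weights_to(keras_model)); the rest simply reuses the architecture from model_fn:

def build_keras_model():
  # Same architecture as in model_fn, but kept as a plain Keras model so it
  # can be saved in HDF5 format.
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(segment_size, num_input_channels)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(units=400, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(units=100, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(activityCount, activation='softmax'),
  ])

keras_model = build_keras_model()
# Copy the aggregated server weights (state.model, or restored_state.model
# after loading a checkpoint) into the Keras model.
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
keras_model.save('federated_model.h5')  # standard Keras HDF5 save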

One thing to note: the code pointers referenced above are generally under tff.python.research..., which is not included in the pip package; so the preferred way to get them is either to fork the code into your own project, or to pull down the repository and build from source.