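I want to save a TensorFlow Federated model trained with the FedAvg algorithm as a Keras / .h5 model. I couldn't find any documentation on this and would like to know how it can be done. Also, if possible, I would like to access both the aggregated server model and the client models.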
The code I use to train the federated model is as follows:
```python
import time

import tensorflow as tf
import tensorflow_federated as tff

# segment_size, num_input_channels, dropout_rate, activityCount, batch,
# learningRate, communicationRound, train_data, filepath and architectureType
# are defined elsewhere in the script.

def model_fn():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(segment_size, num_input_channels)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=400, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(units=100, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(activityCount, activation='softmax'),
    ])
    return tff.learning.from_keras_model(
        model,
        dummy_batch=batch,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))

def evaluate(num_rounds=communicationRound):
    state = trainer.initialize()
    roundMetrics = []
    evaluation = tff.learning.build_federated_evaluation(model_fn)
    for round_num in range(num_rounds):
        t1 = time.time()
        state, metrics = trainer.next(state, train_data)
        t2 = time.time()
        # Note: this evaluates on train_data, so the "test accuracy" below is
        # measured on the training clients, not on a held-out set.
        test_metrics = evaluation(state.model, train_data)
        roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(
            round_num, metrics.sparse_categorical_accuracy, metrics.loss))
        roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
        roundMetrics.append('round time={}'.format(t2 - t1))
        print('round {:2d}, accuracy={}, loss={}'.format(
            round_num, metrics.sparse_categorical_accuracy, metrics.loss))
        print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
        print('round time={}'.format(t2 - t1))
    # Write all collected per-round lines once training has finished.
    with open(filepath + 'stats' + architectureType + '.txt', "w") as outF:
        for line in roundMetrics:
            outF.write(line)
            outF.write("\n")
```
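For context on the second part of the question, here is a minimal, TFF-free sketch of FedAvg's aggregation step: the server model is the example-count-weighted average of that round's client models. In TFF's `build_federated_averaging_process` the clients actually send weight deltas, and with the default server optimizer (SGD with learning rate 1.0) applying their weighted mean amounts to this average. The function name `fedavg_aggregate` and the flattened list-of-floats weight representation are assumptions for illustration only. Because the per-client weights exist only inside a round, the stock process returns the aggregate in `state.model` rather than the individual client models.

```python
from typing import List, Sequence


def fedavg_aggregate(client_weights: Sequence[Sequence[float]],
                     num_examples: Sequence[int]) -> List[float]:
    """Hypothetical helper: weighted average of flattened client weight vectors."""
    if not client_weights or len(client_weights) != len(num_examples):
        raise ValueError("need one example count per client")
    total = sum(num_examples)
    server = [0.0] * len(client_weights[0])
    for weights, n in zip(client_weights, num_examples):
        # Each client contributes in proportion to its number of examples.
        for i, w in enumerate(weights):
            server[i] += w * n / total
    return server


clients = [[1.0, 2.0], [3.0, 6.0]]  # two clients, two parameters each
counts = [1, 3]                      # the second client holds 3x as much data
print(fedavg_aggregate(clients, counts))  # [2.5, 5.0]
```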
Answer:
Roughly, you'll want to use the `save_checkpoint` / `load_checkpoint` methods. In particular, you can instantiate a `FileCheckpointManager` and ask it to save the state (almost) directly.
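The `state` in your example is an instance of `tff.python.common_libs.anonymous_tuple.AnonymousTuple` (IIRC), which is not compatible with `tf.convert_to_tensor`; `save_checkpoint` requires that compatibility, as stated in its docstring. The common solution used in TFF research code is to introduce a Python `attrs` class and convert the state from the anonymous tuple into an instance of that class as soon as it is returned.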
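A minimal sketch of that conversion. The answer refers to an `attrs` class; to keep the example dependency-free, a standard-library `dataclass` is used instead, and a `namedtuple` (hypothetically named `AnonTupleStandIn`) stands in for TFF's `AnonymousTuple`, since both expose fields by attribute name. The field names `model` and `optimizer_state` are illustrative; the server state in your TFF version may carry additional fields, each of which would be copied the same way.

```python
import collections
from dataclasses import asdict, dataclass
from typing import Any

# Stand-in for TFF's AnonymousTuple: fields are accessed by attribute name.
AnonTupleStandIn = collections.namedtuple(
    "AnonTupleStandIn", ["model", "optimizer_state"])


@dataclass(frozen=True)
class ServerState:
    """Plain Python container for the server state (dataclass in place of attrs)."""
    model: Any
    optimizer_state: Any

    @classmethod
    def from_anon_tuple(cls, anon_tuple):
        # Copy each field out by name so the result no longer depends on the tuple type.
        return cls(model=anon_tuple.model,
                   optimizer_state=anon_tuple.optimizer_state)


anon = AnonTupleStandIn(model={"dense/kernel": [0.1, 0.2]}, optimizer_state=[0])
state = ServerState.from_anon_tuple(anon)
print(asdict(state))  # {'model': {'dense/kernel': [0.1, 0.2]}, 'optimizer_state': [0]}
```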
Assuming the above, the following sketch should work:
```python
# state assumed an anonymous tuple, previously created
# N some integer
ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)
```
To restore from this checkpoint, you can call at any time:
```python
# The freshly initialized state serves as a template for the structure to restore.
state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))
```
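To illustrate the save/restore-latest pattern without depending on the research code, here is a hypothetical, stdlib-only stand-in; `SimpleCheckpointManager` is an invented name, not the TFF class, and its return value `(state, round_num)` is a design choice of this sketch. It writes one pickle file per round and restores the highest-numbered one. Unlike `FileCheckpointManager`, which goes through `tf.convert_to_tensor` and takes a template structure, pickle stores plain Python objects directly, so no template argument is needed here.

```python
import os
import pickle
import re
import tempfile
from typing import Any, Optional, Tuple


class SimpleCheckpointManager:
    """Hypothetical stand-in: one pickle file per round inside root_dir."""

    _PATTERN = re.compile(r"^ckpt_(\d+)\.pkl$")

    def __init__(self, root_dir: str):
        self._root_dir = root_dir
        os.makedirs(root_dir, exist_ok=True)

    def save_checkpoint(self, state: Any, round_num: int) -> str:
        path = os.path.join(self._root_dir, f"ckpt_{round_num}.pkl")
        with open(path, "wb") as f:
            pickle.dump(state, f)
        return path

    def load_latest_checkpoint(self) -> Tuple[Optional[Any], int]:
        # Compare round numbers numerically, so round 10 beats round 9.
        rounds = [int(m.group(1)) for name in os.listdir(self._root_dir)
                  if (m := self._PATTERN.match(name))]
        if not rounds:
            return None, 0  # nothing saved yet
        latest = max(rounds)
        with open(os.path.join(self._root_dir, f"ckpt_{latest}.pkl"), "rb") as f:
            return pickle.load(f), latest


with tempfile.TemporaryDirectory() as tmp:
    manager = SimpleCheckpointManager(tmp)
    manager.save_checkpoint({"round": 1}, round_num=1)
    manager.save_checkpoint({"round": 3}, round_num=3)
    state, round_num = manager.load_latest_checkpoint()
    print(round_num, state)  # 3 {'round': 3}
```

In a training loop you would call `save_checkpoint` every few rounds and, on restart, resume from the returned round number.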
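One thing to note: the code pointers linked above generally live in tff.python.research ..., which is not included in the pip package. The preferred way to get them is therefore either to copy the code into your own project or to pull down the repository and build from source.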