我正在尝试让我的KerasRegressor模型与TFF框架一起使用。但似乎“ tff.learning.from_compiled_keras_model”不接受它,对吗?我的主要目的是在回归问题和分类问题上区分/评估联盟。
我尝试使用“ tf.keras.wrappers.scikit_learn.KerasRegressor”作为回归问题的一种手段,如下代码所示。
这是我与代码相关的部分:
def create_SK_model():
modelF = create_SGD_model()
modelF.compile(loss=tf.keras.losses.MSE,optimizer=tf.keras.optimizers.SGD(lr=learn_rate))
return modelF
def create_Reg_model():
modelF_Reg = tf.keras.wrappers.scikit_learn.KerasRegressor(build_fn = create_SK_model,nb_epoch=SNN_epoch, batch_size=SNN_batch_size)
return modelF_Reg
def create_Class_model():
modelF_Reg = tf.keras.wrappers.scikit_learn.KerasClassifier(build_fn = create_SK_model,nb_epoch=SNN_epoch, batch_size=SNN_batch_size)
return modelF_Class
def create_Single_model():
if Use_RegClas:
if Use_Regressor:
return create_Reg_model()
elif Use_Classification:
return create_Class_model()
else:
return create_SK_model()
def model_fn_Federated():
if Use_RegClas:
if Use_Regressor:
return tff.learning.from_compiled_keras_model(create_Reg_model,sample_batch)
elif Use_Classification:
return tff.learning.from_compiled_keras_model(create_Class_model(),sample_batch)
elif Use_FLAveraging:
return tff.learning.from_compiled_keras_model(create_SK_model(),sample_batch)
else:
return tff.learning.from_keras_model(create_SGD_model(),sample_batch,loss=tf.keras.losses.MSE)
................... some other code ..................
if Use_FLAveraging:
trainer_Itr_Process = tff.learning.build_federated_averaging_process(model_fn_Federated,server_optimizer_fn=(lambda : tf.keras.optimizers.SGD(learning_rate=learn_rate)),client_weight_fn=None)
else:
trainer_Itr_Process = tff.learning.build_federated_sgd_process(model_fn_Federated,server_optimizer_fn=(lambda : tf.keras.optimizers.SGD(learning_rate=learn_rate)),client_weight_fn=None)
我的主要问题是如何将回归/分类问题纳入联邦TF框架。我尝试了上面的实现,对吗?错误?请指教。
基于上述实现,我得到以下错误:
.....
py_typecheck.check_type(keras_model, tf.keras.Model)
File "/home/..../.local/lib/python3.6/site-packages/tensorflow_federated/python/common_libs/py_typecheck.py", line 48, in check_type
type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.keras.engine.training.Model, found function.