Tensorflow联合上的KerasRegressor

时间:2019-06-23 11:20:37

标签: tensorflow-federated

我正在尝试让我的KerasRegressor模型与TFF框架一起使用。但似乎“ tff.learning.from_compiled_keras_model”不接受它,对吗?我的主要目的是在回归问题和分类问题上区分/评估联盟。

我尝试使用“ tf.keras.wrappers.scikit_learn.KerasRegressor”作为回归问题的一种手段,如下代码所示。

这是我与代码相关的部分:

def create_SK_model():
    modelF = create_SGD_model()
    modelF.compile(loss=tf.keras.losses.MSE,optimizer=tf.keras.optimizers.SGD(lr=learn_rate))
    return modelF

def create_Reg_model():
    modelF_Reg = tf.keras.wrappers.scikit_learn.KerasRegressor(build_fn = create_SK_model,nb_epoch=SNN_epoch, batch_size=SNN_batch_size)

    return modelF_Reg

def create_Class_model():
    modelF_Reg = tf.keras.wrappers.scikit_learn.KerasClassifier(build_fn = create_SK_model,nb_epoch=SNN_epoch, batch_size=SNN_batch_size)
    return modelF_Class

def create_Single_model():
    if Use_RegClas:
        if Use_Regressor:
            return create_Reg_model()
        elif Use_Classification:
            return create_Class_model()
    else:
        return create_SK_model()
def model_fn_Federated():
    if Use_RegClas:
        if Use_Regressor:
            return tff.learning.from_compiled_keras_model(create_Reg_model,sample_batch)
        elif Use_Classification:
            return tff.learning.from_compiled_keras_model(create_Class_model(),sample_batch)
    elif Use_FLAveraging:
        return tff.learning.from_compiled_keras_model(create_SK_model(),sample_batch)
    else:
        return tff.learning.from_keras_model(create_SGD_model(),sample_batch,loss=tf.keras.losses.MSE)


................... some other code ..................

if Use_FLAveraging:
    trainer_Itr_Process = tff.learning.build_federated_averaging_process(model_fn_Federated,server_optimizer_fn=(lambda : tf.keras.optimizers.SGD(learning_rate=learn_rate)),client_weight_fn=None)
else:
    trainer_Itr_Process = tff.learning.build_federated_sgd_process(model_fn_Federated,server_optimizer_fn=(lambda : tf.keras.optimizers.SGD(learning_rate=learn_rate)),client_weight_fn=None)


我的主要问题是如何将回归/分类问题纳入联邦TF框架。我尝试了上面的实现,对吗?错误?请指教。

基于上述实现,我得到以下错误:

.....
    py_typecheck.check_type(keras_model, tf.keras.Model)
  File "/home/..../.local/lib/python3.6/site-packages/tensorflow_federated/python/common_libs/py_typecheck.py", line 48, in check_type
    type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.keras.engine.training.Model, found function.

0 个答案:

没有答案