我有如下方法:
def importFrom(module, name):
module = importlib.import_module(module)
return getattr(module, name)
然后按以下方式使用它:
def imputation_LR (df, name):
reg = importFrom('sklearn.linear_model', name)
reg.fit(X_train, y_train)
然后按如下方式命名:
data = imputation_LR (data, 'LinearRegression')
并得到以下错误:
reg.fit(X_train, y_train)
TypeError: fit() missing 1 required positional argument: 'y'
我觉得这是关于LinearRegression
/ LinearRegression()
的事情,但无法弄清。
谢谢。
答案 0 :(得分:0)
您需要创建一个LinearRegression
对象。出现错误的原因是因为fit
是实例方法(所有方法都将self
作为第一个参数传递)。因此,函数接口实际上是fit(self, x, y)
。 self
是一个特殊变量,在调用实例方法时由对象引用自动传递。尝试此操作(请注意,我添加了()
,请参见嵌入式注释):
def imputation_LR (df, name):
reg = importFrom('sklearn.linear_model', name)() # Note I added () here
reg.fit(X_train, y_train)
请注意,它位于此处记录的API的“方法”部分下:http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html。
如有疑问,请阅读文档。
P.S。尚不清楚为什么您要尝试将模块作为字符串导入。您可以使用要获取的包/模块的名称进行导入:
from sklearn.linear_model import LinearRegression
reg = LinearRegression()
...