我正在玩save
模型的load
和pyspark.ml.classification
函数。我创建了RandomForestClassifier
的实例,将值设置为几个参数,并称为分类器的save
方法。保存成功。那里没有问题。
from pyspark.ml.classification import RandomForestClassifier
# save
rf = RandomForestClassifier()
rf.setImpurity('entropy')
rf.setPredictionCol('predme')
rf.write().overwrite().save('rf_test')
然后我尝试将其加载回去,但我注意到它的参数没有保存之前设置的值。下面是我正在尝试的代码
# load
rf2 = RandomForestClassifier()
rf2.load('rf_test')
print(rf2.getImpurity()) # returns gini
print(rf2.getPredictionCol()) # returns prediction
我想我对这段代码应该如何工作以及实际上如何工作的理解有所不同。
我该怎么做才能以保存对象的方式找回对象?
编辑
我尝试了这里提到的方法。但这没有用。这就是我尝试过的
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier()
rf.setImpurity('entropy')
rf.setPredictionCol('predme')
rf.write().overwrite().save('rf_test')
rf2 = RandomForestClassifier
rf2.load('rf_test')
print(rf2.getImpurity())
返回了以下
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: getImpurity() missing 1 required positional argument: 'self'
答案 0 :(得分:1)
那不是您应该使用load
方法的方式。它是classmethod
,应在类对象(而不是实例)上调用以返回新对象:
rf2 = RandomForestClassifier.load('rf_test')
rf2.getImpurity()
从技术上讲,在实例上调用它也可以,但是不会修改调用者,而是返回一个新的独立对象:
rf2 = RandomForestClassifier().load('rf_test')
在实践中,应该避免这种构造。