Pyspark.ml - Error when loading a model or pipeline

Posted: 2020-10-14 15:45:33

Tags: apache-spark pyspark spark3

I want to import a trained pyspark model (or pipeline) into a pyspark script. I trained a decision tree model like this:

from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer

# Create assembler and label indexer to prepare the data in spark.ml format
assembler = VectorAssembler(inputCols = requiredFeatures, outputCol = 'features')
label_indexer = StringIndexer(inputCol='measurement_status', outputCol='indexed_label')

# Apply transformations
eq_df_labelled = label_indexer.fit(eq_df).transform(eq_df)
eq_df_labelled_featured = assembler.transform(eq_df_labelled)

# Split into training and testing datasets
(training_data, test_data) = eq_df_labelled_featured.randomSplit([0.75, 0.25])

# Create a decision tree algorithm
dtree = DecisionTreeClassifier(
    labelCol ='indexed_label',
    featuresCol = 'features',
    maxDepth = 5,
    minInstancesPerNode=1,
    impurity = 'gini',
    maxBins=32,
    seed=None
)

# Fit classifier object to training data
dtree_model = dtree.fit(training_data)

# Save model to given directory
dtree_model.save("models/dtree")

All of the code above runs fine. The problem is that when I try to load the model (in the same or in another pyspark application) like this:

from pyspark.ml.classification import DecisionTreeClassifier

imported_model = DecisionTreeClassifier()
imported_model.load("models/dtree")

I get the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-4-b283bc2da75f> in <module>
      2 
      3 imported_model = DecisionTreeClassifier()
----> 4 imported_model.load("models/dtree")
      5 
      6 #lodel = DecisionTreeClassifier.load("models/dtree-test/")

~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(cls, path)
    328     def load(cls, path):
    329         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 330         return cls.read().load(path)
    331 
    332 

~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(self, path)
    278         if not isinstance(path, basestring):
    279             raise TypeError("path should be a basestring, got type %s" % type(path))
--> 280         java_obj = self._jread.load(path)
    281         if not hasattr(self._clazz, "_from_java"):
    282             raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"

~/.local/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

~/.local/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
    126     def deco(*a, **kw):
    127         try:
--> 128             return f(*a, **kw)
    129         except py4j.protocol.Py4JJavaError as e:
    130             converted = convert_exception(e.java_exception)

~/.local/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o39.load.
: java.lang.UnsupportedOperationException: empty collection
    at org.apache.spark.rdd.RDD.$anonfun$first$1(RDD.scala:1439)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:388)
    at org.apache.spark.rdd.RDD.first(RDD.scala:1437)
    at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:587)
    at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:465)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

I tried this approach because loading a Pipeline object fails as well. Any ideas?

Update

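I have realized that this error only occurs when I use a Spark cluster (one master and two workers, using Spark's standalone cluster manager). If I set up the Spark session like this (with the master set to local):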

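from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()  # placeholder: the original configuration was not shown

spark = SparkSession\
    .builder\
    .config(conf=conf)\
    .appName("MachineLearningTesting")\
    .master("local[*]")\
    .getOrCreate()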

I don't get the error above.

Also, I'm using Spark 3.0.0. Could there still be a bug in importing and exporting models in Spark 3?

Answer (score: 0):

There were two problems:

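  1. SSH authentication must be enabled for communication between all nodes in the cluster. Even though all the nodes in my Spark cluster were on the same network, only the master had SSH authentication to the workers, not the other way round.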

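  2. The model must be available to every node in the cluster. This may sound obvious, but I assumed the model files only needed to be available to the master, which would then propagate them to the worker nodes. In other words, when you load a model like this:

from pyspark.ml.classification import DecisionTreeClassifier

imported_model = DecisionTreeClassifier()
imported_model.load("models/dtree")

the files under models/dtree must exist on every machine in the cluster. This made me realize that, in a production environment, models can be made accessible through an external shared file system.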
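The `empty collection` error in the traceback is raised from `DefaultParamsReader.loadMetadata`, which reads the model's `metadata` subdirectory and calls `first()` on it; if the node doing the read sees no metadata part files, that collection is empty. One way to check point 2 before calling `load` is to run a small script on the master and on each worker. `check_model_dir` is a hypothetical helper, assuming the model was saved to a plain local path (not HDFS or S3) and that Spark wrote the metadata as `part-*` text files holding one JSON object with a `class` field:

```python
import glob
import json
import os


def check_model_dir(path):
    """Return the saved class name if `path` looks like a readable Spark ML
    model directory on this machine, else None (hypothetical helper)."""
    metadata_dir = os.path.join(path, "metadata")
    # The metadata is written as text part files (part-00000, ...).
    for part in sorted(glob.glob(os.path.join(metadata_dir, "part-*"))):
        with open(part) as f:
            line = f.readline().strip()
        if line:
            return json.loads(line).get("class")
    # No non-empty part file here: loading on this node would find nothing.
    return None


# Run this on every node; None on any node means the model is missing there.
print(check_model_dir("models/dtree"))
```

If any node prints `None`, copying the directory there or saving to a shared file system that all nodes mount are the two options the answer describes.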

These two steps solved my problem of loading a pyspark model into a Spark application running on a cluster.