XGBoost model training failed

Time: 2019-12-09 06:39:30

Tags: python scala apache-spark pyspark xgboost

I want to use dmlc xgboost from Python. To do this, I wrote a wrapper in Scala, compiled it with sbt, packaged the code into a jar, and put the jar in the $SPARK_HOME/jars folder.

This is the Scala code I wrote, and here is the jar generated by sbt.
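The setup above relies on Spark picking the wrapper jar up from $SPARK_HOME/jars. As a minimal sketch, a stdlib-only script can list which jars in that folder match given name fragments. The fragments used here (`xgboost4j`, `xgboost4j-spark`, `scalapyspark`) are assumptions, since the actual jar file names are not shown. If the sbt jar was not built as an assembly (fat) jar, the XGBoost4J and XGBoost4J-Spark jars its code calls into also have to be on the classpath.

```python
import os
from pathlib import Path


def find_jars(spark_home, name_fragments):
    """For each name fragment, return the sorted jar file names in
    <spark_home>/jars whose name contains that fragment."""
    jars_dir = Path(spark_home) / "jars"
    if not jars_dir.is_dir():
        # No jars folder at all: report every fragment as missing
        return {frag: [] for frag in name_fragments}
    # Only consider regular files ending in .jar
    jar_names = sorted(p.name for p in jars_dir.glob("*.jar") if p.is_file())
    return {frag: [n for n in jar_names if frag in n] for frag in name_fragments}


if __name__ == "__main__":
    spark_home = os.environ.get("SPARK_HOME")
    if spark_home:
        # Hypothetical name fragments; adjust to the real jar file names.
        found = find_jars(spark_home, ["xgboost4j", "xgboost4j-spark", "scalapyspark"])
        for frag, names in found.items():
            print(f"{frag}: {names if names else 'NOT FOUND'}")
    else:
        print("SPARK_HOME is not set")
```

Matching on fragments rather than exact names keeps the check independent of Scala and XGBoost version suffixes in the file names.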

Using the Python code below, I am trying to access XGBClassifier:

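import json

from pyspark.sql import SparkSession

# In the pyspark shell `spark` already exists; getOrCreate() returns that session.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext
# Py4J handle to the Scala wrapper class packaged in the jar
scala_class = sc._jvm.com.scalapyspark.XGBClassifier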
scala_class = sc._jvm.com.scalapyspark.XGBClassifier

# a small wrapper class to instantiate the scala_class and take care of conversions
class XGBClassifier:
    def __init__(self):
        self.xgb = scala_class()
        self.xgb.createObject(json.dumps({}))
    def setParams(self, params):
        self.xgb.setParams(json.dumps(params))
    def fit(self, df):
        self.xgb.fit(df._jdf)

params = {
    "labelCol":"high_income",
    "featuresCol":"feature_vector",
    "alpha":0.1,
    "colsampleBylevel":0.9,
}

xgb = XGBClassifier()
xgb.setParams(params)
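Because the wrapper hands parameters to Scala as JSON strings, it can help to see exactly what the Scala side receives. The sketch below runs the same Python wrapper logic against a stand-in object instead of the Py4J proxy; `FakeScalaXGB` is hypothetical and only records the calls, and injecting the class into `XGBClassifier.__init__` is a change made here so the code runs without Spark.

```python
import json


class FakeScalaXGB:
    """Hypothetical stand-in for the Py4J proxy of
    com.scalapyspark.XGBClassifier; it only records the JSON it receives."""

    def __init__(self):
        self.calls = []

    def createObject(self, payload):
        self.calls.append(("createObject", payload))

    def setParams(self, payload):
        self.calls.append(("setParams", payload))


class XGBClassifier:
    # Same wrapper logic as in the question, but the JVM class is injected
    # so it can be exercised without a SparkContext.
    def __init__(self, scala_class):
        self.xgb = scala_class()
        self.xgb.createObject(json.dumps({}))

    def setParams(self, params):
        self.xgb.setParams(json.dumps(params))


params = {
    "labelCol": "high_income",
    "featuresCol": "feature_vector",
    "alpha": 0.1,
    "colsampleBylevel": 0.9,
}

xgb = XGBClassifier(FakeScalaXGB)
xgb.setParams(params)
for name, payload in xgb.xgb.calls:
    print(name, payload)
```

The second printed payload is the exact string passed to the Scala `setParams`, so it can be compared against whatever keys the Scala JSON parsing expects.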

So far everything seems fine, i.e. I don't get any errors or warnings. However, when I try to use it to fit a DataFrame, I get an error:

>>> xgb.fit(df)

Tracker started, with env={}
[Stage 4:=============================>                             (1 + 1) / 2][13:00:29] /Users/nanzhu/code/xgboost/src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 88 extra nodes, 0 pruned nodes, max_depth=6
[0] train-error:0.172416
19/12/09 13:00:30 ERROR RabitTracker: Uncaught exception thrown by worker:
java.lang.InterruptedException
    at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedInterruptibly(AbstractQueuedSynchronizer.java:998)
    at java.util.concurrent.locks.AbstractQueuedSynchronizer.acquireSharedInterruptibly(AbstractQueuedSynchronizer.java:1304)
    at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:206)
    at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:222)
    at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:157)
    at org.apache.spark.util.ThreadUtils$.awaitReady(ThreadUtils.scala:243)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:728)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
    at org.apache.spark.rdd.RDD$$anonfun$foreachPartition$1.apply(RDD.scala:935)
    at org.apache.spark.rdd.RDD$$anonfun$foreachPartition$1.apply(RDD.scala:933)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.foreachPartition(RDD.scala:933)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$4$$anon$1.run(XGBoost.scala:287)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in fit
  File "/usr/local/lib/python3.7/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/local/lib/python3.7/site-packages/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/usr/local/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling o54.fit.
: ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:364)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$4.apply(XGBoost.scala:294)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$4.apply(XGBoost.scala:256)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.immutable.List.map(List.scala:296)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:255)
    at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:200)
    at ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train(XGBoostClassifier.scala:48)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:118)
    at com.scalapyspark.XGBClassifier.fit(wrapped_xgboost.scala:149)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

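In a log like the one above, the error headers are buried among stack frames. A small stdlib-only scanner can pull them out. This is a sketch: the two regular expressions are assumptions tuned to the lines in this log (the `yy/MM/dd HH:mm:ss LEVEL Logger: message` layout of the Spark ERROR line and `package.ClassName[: message]` exception headers) and will miss other formats.

```python
import re

# Spark/log4j line as seen in this log, e.g.
# "19/12/09 13:00:30 ERROR RabitTracker: Uncaught exception thrown by worker:"
LOG4J_ERROR = re.compile(r"^\d{2}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} ERROR (\S+): (.*)$")

# Exception header, optionally prefixed by ": " as Py4J prints the Java cause, e.g.
# ": ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed"
EXCEPTION_LINE = re.compile(
    r"^:?\s*((?:[\w$]+\.)+[\w$]*(?:Exception|Error))\b(?::\s*(.*))?$"
)


def summarize_errors(log_text):
    """Return (kind, name, message) tuples for ERROR lines and exception
    headers; stack frames ("at ...") and traceback lines do not match."""
    found = []
    for raw in log_text.splitlines():
        line = raw.strip()
        m = LOG4J_ERROR.match(line)
        if m:
            found.append(("log4j", m.group(1), m.group(2)))
            continue
        m = EXCEPTION_LINE.match(line)
        if m:
            found.append(("exception", m.group(1), m.group(2) or ""))
    return found


# Excerpt of the log from the question
SAMPLE_LOG = "\n".join([
    "[0] train-error:0.172416",
    "19/12/09 13:00:30 ERROR RabitTracker: Uncaught exception thrown by worker:",
    "java.lang.InterruptedException",
    "    at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedInterruptibly(AbstractQueuedSynchronizer.java:998)",
    "Traceback (most recent call last):",
    "py4j.protocol.Py4JJavaError: An error occurred while calling o54.fit.",
    ": ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed",
    "    at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:364)",
])

for entry in summarize_errors(SAMPLE_LOG):
    print(entry)
```

On this excerpt it reports the RabitTracker ERROR line, the worker's `InterruptedException`, the `Py4JJavaError` and the underlying `XGBoostError`, in the order they appear.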

I searched for the cause of this error but could not find anything I could use.

How can I fix this implementation?

0 answers:

No answers