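I have Spark code written in Python that uses an XGBoost model for forecasting. The problem I'm facing is that the code has a "for loop": in each iteration it uses the XGBoost model on a different dataset and saves the trained model.
The code runs fine (no coding errors), but after the model loop has run 10-12 times, it throws the following error and crashes my Spark application. Even retrying doesn't help.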
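Some troubleshooting I have done on my job:
-> It is not a data issue. If I restart the code with the same data, the iteration that failed now runs successfully, and the job fails again only after several more loop iterations.
-> I tried increasing RAM and cores and monitored CPU/executors during the run. I didn't see any issues, so it is not a resource problem.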
Exception processing forecast: An error occurred while calling o17979.fit.
: ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.immutable.List.map(List.scala:296)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
at ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor.train(XGBoostRegressor.scala:190)
at ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor.train(XGBoostRegressor.scala:48)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:118)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Any pointers on this would be great. Thanks in advance.
Code:
# Loop through all clusters and train model
for cluster in range(num_clusters):
    logger.info(f'Cluster : {cluster}')
    logger.info(f'Run Start time : {datetime.now()}')
    # Set model path and model name
    model_path = 's3a://' + self.s3_bucket + '/'
    model_name = 'Forecast' + 'Cluster_' + str(cluster)
    logger.info(f'Model Path : {model_path}')
    logger.info(f'Model Name : {model_name}')
    # Filter based on cluster and cache
    forecast_vector_cluster_df = forecast_vector_df.where(f'cluster={cluster}')
    forecast_vector_cluster_df.cache()
    logger.info(f'Cluster Filter DF Count : {forecast_vector_cluster_df.count()}')
    # Set up params for the XGBoost model
    xgbRegressor = XGBoostRegressor(**self.model_parammap) \
        .setFeaturesCol("features") \
        .setLabelCol("predict") \
        .setPredictionCol(f"prediction_{cluster}")
    logger.info('fitting model')
    xgboostModel = xgbRegressor.fit(forecast_vector_cluster_df)
    logger.info('saving model')
    xgboostModel.write().overwrite().save(model_path + model_name)
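One thing worth ruling out (an assumption, not a confirmed cause of the XGBoostError): the loop calls cache() on every per-cluster DataFrame but never calls unpersist(), so cached subsets from earlier iterations stay registered across the whole run. The sketch below shows the loop shape with guaranteed cleanup. The names train_per_cluster, select_cluster and fit_and_save are hypothetical; the only assumption about the DataFrame is that it has Spark-style cache()/unpersist() methods, which pyspark DataFrames do.

```python
from typing import Callable, Iterable, List, Protocol


class Cacheable(Protocol):
    """Anything with Spark-style cache()/unpersist() methods (e.g. a pyspark DataFrame)."""

    def cache(self) -> "Cacheable": ...

    def unpersist(self) -> "Cacheable": ...


def train_per_cluster(
    clusters: Iterable[int],
    select_cluster: Callable[[int], Cacheable],
    fit_and_save: Callable[[int, Cacheable], None],
) -> List[int]:
    """Hypothetical helper: train one model per cluster, releasing each cached subset.

    Returns the clusters that were trained successfully. If fit_and_save raises,
    the exception propagates after the subset has been unpersisted.
    """
    done = []
    for cluster in clusters:
        subset = select_cluster(cluster)
        subset.cache()
        try:
            fit_and_save(cluster, subset)
            done.append(cluster)
        finally:
            # Free cached storage even if fit() raises, so subsets from
            # earlier iterations do not accumulate over the loop.
            subset.unpersist()
    return done
```

With the code from the question, select_cluster would be something like `lambda c: forecast_vector_df.where(f'cluster={c}')`, and fit_and_save would wrap the XGBoostRegressor fit and write().overwrite().save(...) lines.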
Answer 0 (score: 0)
Check the Java version installed on your machine.
Switching to Java 1.8 may solve the problem.
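A minimal sketch for checking the installed Java version from Python, assuming a `java` executable is on the PATH. Note that `java -version` prints to stderr, and that older releases use the "1.x" scheme (Java 8 reports "1.8.0_..."), while Java 9 and later report the major number directly. This only checks the `java` found on the local PATH; the JVM that Spark actually launches may instead come from JAVA_HOME. The function names are hypothetical.

```python
import re
import subprocess
from typing import Optional


def java_major_version(version_text: str) -> Optional[int]:
    """Extract the major Java version from `java -version` output.

    Legacy scheme: 'java version "1.8.0_292"'            -> 8
    Modern scheme: 'openjdk version "11.0.2" 2019-01-15' -> 11
    Returns None if no version string is found.
    """
    match = re.search(r'version "(\d+)(?:\.(\d+))?', version_text)
    if match is None:
        return None
    first = int(match.group(1))
    if first == 1 and match.group(2) is not None:
        return int(match.group(2))  # "1.8" means Java 8
    return first


def local_java_major_version() -> Optional[int]:
    """Run `java -version` on this machine and parse its output (None if java is missing)."""
    try:
        result = subprocess.run(["java", "-version"], capture_output=True, text=True)
    except FileNotFoundError:
        return None  # no `java` executable on PATH
    return java_major_version(result.stderr + result.stdout)
```

For example, `java_major_version('java version "1.8.0_292"')` returns 8, which is the "java1.8" the answer recommends.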