API = ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
I am running a job that stalls until the number of threads available to the API equals the num_workers set for it.
So in master = local mode it works: when I run with --master local[n] and set the API's num_workers to the same value n, the job completes.
In a cluster, however, I cannot figure out which parameter controls exactly the number of processing threads. I have tried:
1) spark.task.cpus
2) spark.default.parallelism
3) spark.executor.cores (executor cores)
None of them works, and the peculiar thing about this problem is that if the condition above is not met, the job does not fail but simply stalls while distributing the XGBoost model.
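To make the condition concrete: local[n] gives n concurrent task slots, and on a cluster the analogous number is the total count of task slots the scheduler can run at once. Here is a minimal sketch of the arithmetic as I understand it; the object, helper, and example numbers are hypothetical illustrations, not part of XGBoost's API:

```scala
// Sketch of the capacity condition, not XGBoost's actual internals.
// XGBoost4J-Spark launches num_workers tasks that must all run at the
// same time, so the cluster needs at least that many concurrent slots:
//   slots = numExecutors * executorCores / taskCpus
object WorkerCapacityCheck {
  def maxConcurrentTasks(numExecutors: Int, executorCores: Int, taskCpus: Int): Int =
    (numExecutors * executorCores) / taskCpus

  def main(args: Array[String]): Unit = {
    // Hypothetical cluster: 2 executors with 2 cores each, 1 CPU per task
    val slots = maxConcurrentTasks(numExecutors = 2, executorCores = 2, taskCpus = 1)
    val numWorkers = 2 // as in the parameter map above
    // If numWorkers > slots, fit() would hang waiting for tasks that never start
    println(s"slots=$slots, canSchedule=${numWorkers <= slots}")
  }
}
```

If this reasoning is right, the stall would be explained by num_workers exceeding the available slots rather than by any single Spark property on its own.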
My code is below; it works in local mode but not on a cluster. Any help?
Code:
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
// Schema for the Iris CSV: four numeric features plus a string label
val schema = new StructType(Array(
StructField("sepal length", DoubleType, true),
StructField("sepal width", DoubleType, true),
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
val rawInput = spark.read.schema(schema).csv("file:///appdata/bblite-data/iris.csv")
import org.apache.spark.ml.feature.StringIndexer
// Convert the string label into a numeric index
val stringIndexer = new StringIndexer().
setInputCol("class").
setOutputCol("classIndex").
fit(rawInput)
val labelTransformed = stringIndexer.transform(rawInput).drop("class")
import org.apache.spark.ml.feature.VectorAssembler
// Assemble the four feature columns into a single vector column
val vectorAssembler = new VectorAssembler().
setInputCols(Array("sepal length", "sepal width", "petal length", "petal width")).
setOutputCol("features")
val xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex")
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
val xgbParam = Map("eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
// Train the classifier; this is the step that stalls on the cluster
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
val xgbClassificationModel = xgbClassifier.fit(xgbInput)