我按照此处(here)所述的方法编写了一个自定义转换器(Transformer)。
当我把这个转换器作为第一个阶段来构建管道(Pipeline)时,可以正常训练用于分类的(逻辑回归)模型。
但是,当我想对这样的管道执行交叉验证时:
from pyspark.ml.feature import HashingTF
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# Toy training set: (label, sentence) pairs fed to the cross-validated pipeline.
# NOTE(review): with only 4 rows and numFolds=4, each randomly assigned
# validation fold holds about one row on average (possibly none), so the
# per-fold AUC is likely degenerate -- consider more data or fewer folds.
training_rows = [
    (1.0, "Hi I heard about Spark"),
    (1.0, "Spark is awesome"),
    (0.0, "But there seems to be a problem"),
    (0.0, "And I don't know why..."),
]
sentenceDataFrame = sqlContext.createDataFrame(training_rows, ["label", "sentence"])

# Pipeline: custom Python tokenizer -> hashed term frequencies -> classifier.
english_stopwords = set(nltk.corpus.stopwords.words('english'))
tokenizer = NLTKWordPunctTokenizer(
    inputCol="sentence",
    outputCol="words",
    stopwords=english_stopwords,
)
hasher = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression()
pipeline = Pipeline(stages=[tokenizer, hasher, lr])

# 2 x 2 grid over regularization strength and convergence tolerance.
paramGrid = (
    ParamGridBuilder()
    .addGrid(lr.regParam, (0.01, 0.1))
    .addGrid(lr.tol, (1e-5, 1e-6))
    .build()
)

# NOTE(review): under Spark 1.4.1 this fit fails with "Can only zip RDDs with
# same number of elements in each partition" when the Python tokenizer stage is
# inside the pipeline, but succeeds when the DataFrame is tokenized beforehand.
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=BinaryClassificationEvaluator(),
    numFolds=4,
)
model = cv.fit(sentenceDataFrame)
我收到以下错误:
Py4JJavaError: An error occurred while calling o59.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 3.0 failed 1 times, most recent failure: Lost task 0.0 in stage 3.0 (TID 3, localhost): org.apache.spark.SparkException: Can only zip RDDs with same number of elements in each partition
at org.apache.spark.rdd.RDD$$anonfun$zip$1$$anonfun$apply$27$$anon$1.hasNext(RDD.scala:812)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
at org.apache.spark.storage.MemoryStore.unrollSafely(MemoryStore.scala:276)
at org.apache.spark.CacheManager.putInBlockManager(CacheManager.scala:171)
at org.apache.spark.CacheManager.getOrCompute(CacheManager.scala:78)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:242)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:35)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:277)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:244)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:63)
at org.apache.spark.scheduler.Task.run(Task.scala:70)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1273)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1264)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1263)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1263)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:730)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:730)
at scala.Option.foreach(Option.scala:236)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:730)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1457)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1418)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
Python stacktrace:
Py4JJavaError Traceback (most recent call last)
<ipython-input-11-780e2ee6bae8> in <module>()
22 numFolds=4)
23
---> 24 model = cv.fit(sentenceDataFrame)
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/pipeline.pyc in fit(self, dataset, params)
63 return self.copy(params)._fit(dataset)
64 else:
---> 65 return self._fit(dataset)
66 else:
67 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/tuning.pyc in _fit(self, dataset)
220 train = df.filter(~condition)
221 for j in range(numModels):
--> 222 model = est.fit(train, epm[j])
223 # TODO: duplicate evaluator to take extra params from input
224 metric = eva.evaluate(model.transform(validation, epm[j]))
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/pipeline.pyc in fit(self, dataset, params)
61 elif isinstance(params, dict):
62 if params:
---> 63 return self.copy(params)._fit(dataset)
64 else:
65 return self._fit(dataset)
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/pipeline.pyc in _fit(self, dataset)
196 dataset = stage.transform(dataset)
197 else: # must be an Estimator
--> 198 model = stage.fit(dataset)
199 transformers.append(model)
200 if i < indexOfLastEstimator:
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/pipeline.pyc in fit(self, dataset, params)
63 return self.copy(params)._fit(dataset)
64 else:
---> 65 return self._fit(dataset)
66 else:
67 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/wrapper.pyc in _fit(self, dataset)
129
130 def _fit(self, dataset):
--> 131 java_model = self._fit_java(dataset)
132 return self._create_model(java_model)
133
~/spark-1.4.1-bin-hadoop2.6/python/pyspark/ml/wrapper.pyc in _fit_java(self, dataset)
126 """
127 self._transfer_params_to_java()
--> 128 return self._java_obj.fit(dataset._jdf)
129
130 def _fit(self, dataset):
~/spark-1.4.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py in __call__(self, *args)
536 answer = self.gateway_client.send_command(command)
537 return_value = get_return_value(answer, self.gateway_client,
--> 538 self.target_id, self.name)
539
540 for temp_arg in temp_args:
~/spark-1.4.1-bin-hadoop2.6/python/lib/py4j-0.8.2.1-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
298 raise Py4JJavaError(
299 'An error occurred while calling {0}{1}{2}.\n'.
--> 300 format(target_id, '.', name), value)
301 else:
302 raise Py4JError(
我可以通过事先转换 DataFrame(即把我的转换器移出管道)来绕过这个错误。但我确实希望把所有步骤都保留在处理管道中:这样在对未见过的新数据进行分类时,就能直接使用这个管道,而无需任何额外的预处理步骤;同时还能对特征提取的参数进行调优。如有任何帮助,不胜感激。