def predict(training_data, test_data):
    """Train a random forest classifier and predict a label for each test point.

    A binary (``numClasses=2``) random forest is trained on ``training_data``
    with the fixed hyper-parameters defined below, then applied to every
    element of ``test_data``.

    Args:
        training_data: RDD accepted by ``RandomForest.trainClassifier``
            (presumably ``LabeledPoint`` records -- confirm against caller).
        test_data: RDD of test points. Elements may be plain feature vectors
            (e.g. ``DenseVector``) or records exposing a ``features``
            attribute; for the latter only ``features`` is used.

    Returns:
        An RDD with the forest's prediction for each element of ``test_data``.
    """
    RANDOM_SEED = 13579
    RF_NUM_TREES = 3
    RF_MAX_DEPTH = 4
    RF_NUM_BINS = 32

    model = RandomForest.trainClassifier(
        training_data,
        numClasses=2,
        categoricalFeaturesInfo={},
        numTrees=RF_NUM_TREES,
        featureSubsetStrategy="auto",
        impurity="gini",
        maxDepth=RF_MAX_DEPTH,
        maxBins=RF_NUM_BINS,
        seed=RANDOM_SEED,
    )

    # test_data may already contain bare feature vectors, which have no
    # `.features` attribute (accessing it raised
    # "AttributeError: 'numpy.ndarray' object has no attribute 'features'").
    # Use `.features` when present, otherwise treat the element itself as
    # the feature vector.
    features = test_data.map(lambda point: getattr(point, "features", point))

    # Call predict on the whole RDD rather than inside a transformation:
    # PySpark tree models cannot be used within RDD operations like map().
    return model.predict(features)
I get the following error:
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 122.0 failed 1 times, most recent failure: Lost task 0.0 in stage 122.0 (TID 226, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "C:\Spark\spark-2.4.3-bin-hadoop2.7\spark-2.4.3-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 377, in main
File "C:\Spark\spark-2.4.3-bin-hadoop2.7\spark-2.4.3-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\worker.py", line 372, in process
File "C:\Spark\spark-2.4.3-bin-hadoop2.7\spark-2.4.3-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\serializers.py", line 393, in dump_stream
vs = list(itertools.islice(iterator, batch))
File "C:\Spark\spark-2.4.3-bin-hadoop2.7\spark-2.4.3-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<ipython-input-20-170be0983095>", line 12, in <lambda>
File "C:\Spark\spark-2.4.3-bin-hadoop2.7\spark-2.4.3-bin-hadoop2.7\python\lib\pyspark.zip\pyspark\mllib\linalg\__init__.py", line 483, in __getattr__
return getattr(self.array, item)
AttributeError: 'numpy.ndarray' object has no attribute 'features'
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.hasNext(SerDeUtil.scala:153)
at scala.collection.Iterator$class.foreach(Iterator.scala:891)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:148)