我有一个使用pyspark训练和测试机器学习模型的python代码,代码如下:
def build_model(self, database_url_training, database_url_test, label):
    """Train a logistic-regression pipeline on a training dataset and
    print predictions for a test dataset.

    Parameters
    ----------
    database_url_training : str
        Location of the training data, consumed by ``self.file_processor``.
    database_url_test : str
        Location of the test data.
    label : str
        Name of the target column in the training dataset
        (e.g. ``"Survived"``).
    """
    training_file = self.file_processor(database_url_training)
    # Spark ML estimators expect the target column to be named "label".
    training_file = training_file.withColumnRenamed(label, "label")

    pre_processing_text = []
    assembler_columns_input = []

    # String columns: tokenize, then hash the tokens into fixed-size
    # (sparse) term-frequency vectors.
    training_string_fields = self.fields_from_dataframe(training_file, True)
    for column in training_string_fields:
        tokenizer = Tokenizer(inputCol=column, outputCol=column + "_words")
        pre_processing_text.append(tokenizer)
        hashing_tf_output_column_name = column + "_features"
        hashing_tf = HashingTF(
            inputCol=tokenizer.getOutputCol(),
            outputCol=hashing_tf_output_column_name)
        pre_processing_text.append(hashing_tf)
        assembler_columns_input.append(hashing_tf_output_column_name)

    # Numeric columns go straight into the assembler.
    # BUG FIX: the target column was renamed to "label" above, so the old
    # comparison `column != label` (the ORIGINAL name) never matched the
    # renamed column — the target leaked into the feature vector, and
    # transform() on a test set without that column breaks. Compare
    # against the renamed name instead.
    training_number_fields = self.fields_from_dataframe(training_file, False)
    for column in training_number_fields:
        if column != "label":
            assembler_columns_input.append(column)

    assembler = VectorAssembler(
        inputCols=assembler_columns_input,
        outputCol="features")
    # Drop rows whose feature columns contain nulls (e.g. missing Age).
    assembler.setHandleInvalid("skip")

    logistic_regression = LogisticRegression(maxIter=10)
    pipeline = Pipeline(
        stages=[*pre_processing_text, assembler, logistic_regression])

    # Empty grid: CrossValidator is used only for k-fold evaluation of the
    # single default parameter set.
    param_grid = ParamGridBuilder().build()
    cross_validator = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=param_grid,
        evaluator=BinaryClassificationEvaluator(),
        numFolds=2)
    cross_validator_model = cross_validator.fit(training_file)

    test_file = self.file_processor(database_url_test)
    prediction = cross_validator_model.transform(test_file)

    # MEMORY FIX: collect() pulls EVERY prediction row — including the
    # tokenized/hashed feature columns — onto the driver at once, which is
    # the likely source of the driver heap exhaustion even on small input.
    # toLocalIterator() streams one partition at a time instead.
    for row in prediction.toLocalIterator():
        print(row, flush=True)
我正在使用较小的数据集进行训练,训练数据集:
PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",female,38,1,0,PC 17599,71.2833,C85,C
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35,1,0,113803,53.1,C123,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
测试数据集:
PassengerId,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
892,3,"Kelly, Mr. James",male,34.5,0,0,330911,7.8292,,Q
893,3,"Wilkes, Mrs. James (Ellen Needs)",female,47,1,0,363272,7,,S
894,2,"Myles, Mr. Thomas Francis",male,62,0,0,240276,9.6875,,Q
895,3,"Wirz, Mr. Albert",male,27,0,0,315154,8.6625,,S
896,3,"Hirvonen, Mrs. Alexander (Helga E Lindqvist)",female,22,1,1,3101298,12.2875,,S
我不明白为什么用这么小的数据集训练还会抛出堆内存不足(insufficient heap memory)的异常。我的Spark集群中有3个worker,仍然抛出了该异常,有人可以帮助我吗?
PS:我已将spark.driver.memory从默认的512 mb设置为1 gb。