首先,我想声明我不能使用 pandas。我想做的是:当 DataFrame 中某个单元格的值与事先确定的特定值匹配时,替换该单元格的值;否则保留单元格原来的值。
这是我到目前为止尝试过的:
# Fit crossval (defined elsewhere) on trainingData and score that same data.
predictions = crossval.fit(trainingData).transform(trainingData)
# Collect every distinct predicted value to the driver; each value defines one bin.
bins = predictions.select("prediction").distinct().collect()
# NOTE(review): indentation appears to have been lost in this paste; the
# statements below are presumably the body of the loop -- confirm.
for row in bins:
# Keep only the rows whose prediction equals this bin's value, and rename the
# prediction column to prediction_1.
rows = predictions.select(["features", "prediction"]).filter(predictions.prediction == row.prediction).withColumnRenamed("prediction", "prediction_1")
# Fit dt (defined elsewhere; presumably a decision-tree estimator -- confirm) on this bin's rows.
dt_model = dt.fit(rows)
# Score testData with the per-bin model and keep only its prediction column.
dt_transform = dt_model.transform(testData).select("prediction")
# This is the statement that raises the AnalysisException: the when() expression
# references columns of rows and dt_transform, which are separate DataFrames
# and not part of predictions.
predictions = predictions.withColumn("prediction", when(predictions.prediction == rows.prediction_1, dt_transform.prediction).otherwise(predictions.prediction))
给我麻烦的那一行是:
predictions = predictions.withColumn("prediction", when(predictions.prediction == rows.prediction_1, dt_transform.prediction).otherwise(predictions.prediction))
它给我的错误是:
Traceback (most recent call last):
File "part2.py", line 114, in <module>
main()
File "part2.py", line 108, in main
predictions = predictions.withColumn("prediction", when(predictions.prediction == rows.prediction_1, dt_transform.prediction).otherwise(predictions.prediction))
File "/opt/spark/python/pyspark/sql/dataframe.py", line 1990, in withColumn
return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
File "/opt/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
File "/opt/spark/python/pyspark/sql/utils.py", line 69, in deco
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
pyspark.sql.utils.AnalysisException: u'Resolved attribute(s) prediction#3065,prediction_1#2949 missing from features#200,trip_duration#20,prediction#2925 in operator !Project [features#200, trip_duration#20, CASE WHEN (prediction#2925 = prediction_1#2949) THEN prediction#3065 ELSE prediction#2925 END AS prediction#3070]. Attribute(s) with the same name appear in the operation: prediction. Please check if the right attribute(s) are used.;;\n!Project [features#200, trip_duration#20, CASE WHEN (prediction#2925 = prediction_1#2949) THEN prediction#3065 ELSE prediction#2925 END AS prediction#3070]\n+- Project [features#200, trip_duration#20, UDF(features#200) AS prediction#2925]\n +- Sample 0.0, 0.8, false, 3709578444707833222\n +- Sort [features#200 ASC NULLS FIRST, trip_duration#20 ASC NULLS FIRST], false\n +- Project [features#200, trip_duration#20]\n +- Project [vendor_id#11, passenger_count#14, store_and_fwd_flag#178, trip_duration#20, distance#33, year#98, month#107, day#117, hour#128, minute#140, second#153, UDF(named_struct(vendor_id_double_VectorAssembler_42efd84316ac, cast(vendor_id#11 as double), passenger_count_double_VectorAssembler_42efd84316ac, cast(passenger_count#14 as double), store_and_fwd_flag_double_VectorAssembler_42efd84316ac, cast(store_and_fwd_flag#178 as double), distance_double_VectorAssembler_42efd84316ac, cast(distance#33 as double), year_double_VectorAssembler_42efd84316ac, cast(year#98 as double), month_double_VectorAssembler_42efd84316ac, cast(month#107 as double), day_double_VectorAssembler_42efd84316ac, cast(day#117 as double), hour_double_VectorAssembler_42efd84316ac, cast(hour#128 as double), minute_double_VectorAssembler_42efd84316ac, cast(minute#140 as double), second_double_VectorAssembler_42efd84316ac, cast(second#153 as double))) AS features#200]\n +- Project [vendor_id#11, passenger_count#14, <lambda>(store_and_fwd_flag#19) AS store_and_fwd_flag#178, trip_duration#20, distance#33, year#98, month#107, day#117, hour#128, 
minute#140, second#153]\n +- Project [vendor_id#11, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, month#107, day#117, hour#128, minute#140, second#153]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, month#107, day#117, hour#128, minute#140, <lambda>(pickup_datetime#12) AS second#153]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, month#107, day#117, hour#128, <lambda>(pickup_datetime#12) AS minute#140]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, month#107, day#117, <lambda>(pickup_datetime#12) AS hour#128]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, month#107, <lambda>(pickup_datetime#12) AS day#117]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, year#98, <lambda>(pickup_datetime#12) AS month#107]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33, <lambda>(pickup_datetime#12) AS year#98]\n +- Project [vendor_id#11, pickup_datetime#12, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [vendor_id#11, pickup_datetime#12, dropoff_datetime#13, passenger_count#14, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [vendor_id#11, pickup_datetime#12, dropoff_datetime#13, passenger_count#14, dropoff_latitude#18, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [vendor_id#11, pickup_datetime#12, dropoff_datetime#13, passenger_count#14, dropoff_longitude#17, dropoff_latitude#18, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [vendor_id#11, pickup_datetime#12, 
dropoff_datetime#13, passenger_count#14, pickup_latitude#16, dropoff_longitude#17, dropoff_latitude#18, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [vendor_id#11, pickup_datetime#12, dropoff_datetime#13, passenger_count#14, pickup_longitude#15, pickup_latitude#16, dropoff_longitude#17, dropoff_latitude#18, store_and_fwd_flag#19, trip_duration#20, distance#33]\n +- Project [id#10, vendor_id#11, pickup_datetime#12, dropoff_datetime#13, passenger_count#14, pickup_longitude#15, pickup_latitude#16, dropoff_longitude#17, dropoff_latitude#18, store_and_fwd_flag#19, trip_duration#20, <lambda>(pickup_longitude#15, pickup_latitude#16, dropoff_longitude#17, dropoff_latitude#18) AS distance#33]\n +- Relation[id#10,vendor_id#11,pickup_datetime#12,dropoff_datetime#13,passenger_count#14,pickup_longitude#15,pickup_latitude#16,dropoff_longitude#17,dropoff_latitude#18,store_and_fwd_flag#19,trip_duration#20] csv\n'
到目前为止我发现,如果把 rows.prediction_1 和 dt_transform.prediction 都替换成 predictions.prediction,代码就能运行,只是结果不是预期的那样。因此,问题出在这两个 DataFrame 上。
predictions.show()的输出为:
+--------------------+-------------+------------------+
| features|trip_duration| prediction|
+--------------------+-------------+------------------+
|[1.0,0.0,0.0,0.0,...| 8| 299.6655053883315|
|[1.0,0.0,0.0,0.02...| 9| 299.6655053883315|
|[1.0,0.0,0.0,15.1...| 2251|2659.7614115841966|
|[1.0,1.0,0.0,0.0,...| 37| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 1084| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 570| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 599| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 21| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 6| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 19| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 177| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 44| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 35| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 60| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 79| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 73| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 705| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 580| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 67| 299.6655053883315|
|[1.0,1.0,0.0,0.0,...| 640| 299.6655053883315|
+--------------------+-------------+------------------+
答案 0 :(得分:1)
备注1:dt_transform = dt_model.transform(testData).select("prediction")
这一行意义不大,因为 testData 和 rows 的行数不匹配。由于 when
函数是逐行操作的,您无法在下一行用它把 testData 上的新预测写回 rows 对应的预测。join
函数会是更好的选择。
备注2:predictions = predictions.withColumn("prediction", when(predictions.prediction == rows.prediction_1, dt_transform.prediction).otherwise(predictions.prediction))
是非法的。withColumn 的一次操作中不能引用多个 DataFrame(这里有三个:predictions、rows 和 dt_transform)。如果要获取或比较其他 DataFrame 中的值,可以使用join
函数。
下面是我写的一个简短示例,用来演示两阶段估计方法。
第1阶段-在全部数据上拟合模型,得到初步预测。
第2阶段-按初步预测把数据划分为子组,在每个子组上重新拟合并更新预测。
注意:我在演示中使用的是聚类(KMeans),不过这个示例可以改写以适应您的回归场景。
代码
from pyspark.sql.types import StructField, StructType, DoubleType
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans
import pyspark.sql.functions as F
# Build a small sample DataFrame with three double columns: a, b and c.
data = [
    (1.54, 3.45, 2.56),
    (9.39, 8.31, 1.34),
    (1.25, 3.31, 9.87),
    (9.35, 5.67, 2.49),
    (1.23, 4.67, 8.91),
    (3.56, 9.08, 7.45),
    (6.43, 2.23, 1.19),
    (7.89, 5.32, 9.08),
]
# All three columns share the same type and nullability, so derive the
# field list from the column names instead of spelling each one out.
fields = [StructField(col_name, DoubleType(), True) for col_name in ('a', 'b', 'c')]
schema = StructType(fields)
df = spark.createDataFrame(data, schema)
df.show()
# +----+----+----+
# | a| b| c|
# +----+----+----+
# |1.54|3.45|2.56|
# |9.39|8.31|1.34|
# |1.25|3.31|9.87|
# |9.35|5.67|2.49|
# |1.23|4.67|8.91|
# |3.56|9.08|7.45|
# |6.43|2.23|1.19|
# |7.89|5.32|9.08|
# +----+----+----+
# Stage 1: cluster the whole data set to get a preliminary prediction per row.
# Pack columns a, b and c into a single 'features' vector column.
assembler = VectorAssembler(
    inputCols=['a', 'b', 'c'],
    outputCol='features',
)
df_trans = assembler.transform(df)
# Three clusters, with a fixed seed.
kmeans = KMeans(seed=123, k=3)
km_model = kmeans.fit(df_trans)
predictions = km_model.transform(df_trans)
# sort() is the same DataFrame method as orderBy().
predictions.sort('prediction').show()
# +----+----+----+----------------+----------+
# | a| b| c| features|prediction|
# +----+----+----+----------------+----------+
# |1.25|3.31|9.87|[1.25,3.31,9.87]| 0|
# |1.23|4.67|8.91|[1.23,4.67,8.91]| 0|
# |3.56|9.08|7.45|[3.56,9.08,7.45]| 0|
# |7.89|5.32|9.08|[7.89,5.32,9.08]| 0|
# |9.39|8.31|1.34|[9.39,8.31,1.34]| 1|
# |9.35|5.67|2.49|[9.35,5.67,2.49]| 1|
# |1.54|3.45|2.56|[1.54,3.45,2.56]| 2|
# |6.43|2.23|1.19|[6.43,2.23,1.19]| 2|
# +----+----+----+----------------+----------+
# Stage 2: re-estimate inside each preliminary cluster and stack the results.
# The distinct Stage 1 predictions are collected to the driver; each one
# defines a sub-group that is fitted separately.
bins = predictions.select("prediction").distinct().collect()
# Accumulator for the stacked per-group results; None until the first group
# has been processed, so no separate iteration counter is needed.
updated_predictions = None
for row in bins:
    # Sub-DataFrame for this preliminary prediction. The old label is renamed
    # to previous_prediction, so the model's own output can use the name
    # 'prediction' without clashing.
    sub_df = (
        predictions.filter(F.col('prediction') == row.prediction)
        .select(['features', 'prediction'])
        .withColumnRenamed('prediction', 'previous_prediction')
    )
    # Reuse the Stage 1 kmeans estimator, fitted on this sub-group only.
    sub_model = kmeans.fit(sub_df)
    sub_predictions = sub_model.transform(sub_df)
    # The first group initializes the result; later groups are appended.
    # union() matches columns by position, which is safe here because every
    # sub_predictions is built by the same sequence of operations.
    if updated_predictions is None:
        updated_predictions = sub_predictions
    else:
        updated_predictions = updated_predictions.union(sub_predictions)
输出
# Display the Stage 2 labels beside the Stage 1 labels, ordered by the
# Stage 1 label; the rename only affects the displayed DataFrame.
(
    updated_predictions
    .orderBy('previous_prediction')
    .withColumnRenamed('prediction', 'updated_prediction')
    .show()
)
# +----------------+-------------------+------------------+
# | features|previous_prediction|updated_prediction|
# +----------------+-------------------+------------------+
# |[1.25,3.31,9.87]| 0| 1|
# |[1.23,4.67,8.91]| 0| 1|
# |[3.56,9.08,7.45]| 0| 0|
# |[7.89,5.32,9.08]| 0| 2|
# |[9.39,8.31,1.34]| 1| 0|
# |[9.35,5.67,2.49]| 1| 1|
# |[1.54,3.45,2.56]| 2| 0|
# |[6.43,2.23,1.19]| 2| 1|
# +----------------+-------------------+------------------+