如何计算每行的最大值,并返回最大值的列和具有相应列名的另一列?

时间:2017-09-12 14:46:51

标签: dataframe pyspark

我得到了一个Pyspark DF,DF =

+--------------------+--------------------+--------------------+
|           avg(F[0])|           avg(F[1])|           avg(F[2])|
+--------------------+--------------------+--------------------+
|                                                              |
|   1.728054550127892| -1.3667432679930283|  1.5750969709542757|
| -1.2435554666317885|  1.9235631250642942| -1.6640951322277968|
| 0.16982762083959863| -0.1535834478084156|  0.3475326658630229|
|   0.477355880659821|  -1.234290049506637|  0.4749928518454093|
| -0.5508890265873237|    1.13481924081605| -0.4360033650587705|
|  1.2016720679403226|  -0.586277913445618|   0.804378397997815|
|-0.23824636151441825| -0.4287653456589916| 0.04471521195350969|
|-0.20732428129005576| -0.4924735928530498|  0.2833979281053236|
| -1.2190324811595223|  1.3641885141303651| -1.6970489822900245|
|  0.6666003701714154|-0.44153017341535095|  1.0085654202707803|
|  0.3770586483507039| -1.1370481655269318|-0.03587840129806219|
|-0.21321645372638362|  1.0865405975548958| 0.11085557630922299|
|  -1.072398519603297|  0.8355439322641093|  -0.241882607400929|
|  0.6389977183433129| -0.5350431348677808| -0.8180005819445212|
| -0.6431203520333384| 0.10044676372867167|  0.6448699599709696|
|  0.5642782985741281| 0.11060183842132806| -0.6709698499147829|
|  0.5864417674723157|  0.3384703463140547|  0.6871676808317047|
| -0.7202689145159678|-0.41922780383853375|-0.29109205561252516|
|-0.19525347142570315|  0.2013279865961808| 0.14113208947213507|
+--------------------+--------------------+--------------------+

我想要的是计算每一行的最大值并返回一个包含2个新列的新DataFrame DF:" maxValue"包含最大值," maxColum"包含相应的列名?

有什么想法吗?

提前致谢

1 个答案:

答案 0 :(得分:3)

希望这有帮助!

from pyspark.sql.functions import col, greatest, udf, array
from pyspark.sql.types import StringType

df = sc.parallelize([(1.728054550127892, -1.3667432679930283, 1.5750969709542757),
                     (-1.2435554666317885, 1.9235631250642942, -1.6640951322277968),
                     (0.16982762083959863, -0.1535834478084156, 0.3475326658630229)]).\
    toDF(["col1", "col2","col3"])

df1 = df.withColumn("maxValue", greatest(*[col(x) for x in df.columns]))
col_arr = df1.columns

def modify_values(r):
    for i in range(len(r[:-1])):
        if r[i]==r[-1]:
            return col_arr[i]
modify_values_udf = udf(modify_values, StringType())
df1 = df1.withColumn("maxColumn", modify_values_udf(array(df1.columns)))
df1.show()

输出是:

+-------------------+-------------------+-------------------+------------------+---------+
|               col1|               col2|               col3|          maxValue|maxColumn|
+-------------------+-------------------+-------------------+------------------+---------+
|  1.728054550127892|-1.3667432679930283| 1.5750969709542757| 1.728054550127892|     col1|
|-1.2435554666317885| 1.9235631250642942|-1.6640951322277968|1.9235631250642942|     col2|
|0.16982762083959863|-0.1535834478084156| 0.3475326658630229|0.3475326658630229|     col3|
+-------------------+-------------------+-------------------+------------------+---------+