如何将PySpark数据帧的每个非字符串列与浮点常量相除或相乘?

时间:2017-06-28 16:18:47

标签: python apache-spark pyspark spark-dataframe pyspark-sql

我的输入数据框如下所示

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Basics").getOrCreate()

df=spark.createDataFrame(data=[('Alice',4.300,None),('Bob',float('nan'),897)],schema=['name','High','Low'])

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice| 4.3|null|
|  Bob| NaN| 897|
+-----+----+----+

预期输出除以10.0

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice| 0.43|null|
|  Bob| NaN| 89.7|
+-----+----+----+

2 个答案:

答案 0 :(得分:6)

我不知道任何可以执行此操作的库函数,但这段代码似乎做得很好:

CONSTANT = 10.0

for field in df.schema.fields:
    if str(field.dataType) in ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']:
        name = str(field.name)
        df = df.withColumn(name, col(name)/CONSTANT)


df.show()

输出:

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice|0.43|null|
|  Bob| NaN|89.7|
+-----+----+----+

答案 1 :(得分:2)

以下代码应以省时的方式解决您的问题

from pyspark.sql.functions import col

allowed_types = ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']

df = df.select(*[(col(field.name)/10).name(field.name) if str(field.dataType) in allowed_types else col(field.name) for field in df.schema.fields]

当列数很大时,迭代使用“ withColumn”可能不是一个好主意。
这是因为PySpark数据帧是不可变的,因此从本质上讲,我们将为使用withColumn强制转换的每一列创建一个新的DataFrame,这将是一个非常缓慢的过程。

这是上面的代码很方便的地方。