TypeError: a float is required (pyspark)

Time: 2018-07-31 09:33:32

Tags: python python-3.x apache-spark pyspark

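I want to compute the haversine distance between consecutive latitude/longitude points. I am using the UCI GPS Trajectories dataset, https://archive.ics.uci.edu/ml/datasets/GPS+Trajectories (go_track_trackspoints.csv), and the code below to compute the distance: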

def dist(lon2, lat2, lon1, lat1):
    phi_1 = toRadians(lat1)
    phi_2 = toRadians(lat2)
    delta_phi = toRadians(lat2 - lat1)
    delta_lambda = toRadians(lon2 - lon1)

    a = sin(delta_phi / 2.0)**2 + cos(phi_1) * cos(phi_2) * sin(delta_lambda / 2.0)**2
    c = 2 * atan2(sqrt(abs(a)), sqrt(abs(1 - a)))
    return c * 6372.8

The schema is:

root
 |-- id: string (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)
 |-- track_id: string (nullable = true)
 |-- time: string (nullable = true)

I have loaded the data into a Spark DataFrame (gps_data):

+---+-----------------+-----------------+--------+-------------------+
| id|         latitude|        longitude|track_id|               time|
+---+-----------------+-----------------+--------+-------------------+
|  1|-10.9393413858164|-37.0627421097422|       1|2014-09-13 07:24:32|
|  2| -10.939341385769|-37.0627421097809|       1|2014-09-13 07:24:37|
|  3|-10.9393239478718|-37.0627645137212|       1|2014-09-13 07:24:42|
|  4|-10.9392105616561|-37.0628430455445|       1|2014-09-13 07:24:47|
+---+-----------------+-----------------+--------+-------------------+

With the following command I am trying to add a distance column:

my_window = Window.partitionBy().orderBy("time")  
gps_d=gps_data.withColumn("dist", dist(
        "longitude", "latitude",
        lag("longitude", 1).over(my_window), lag("latitude", 1).over(my_window)
    ).alias("dist"))

But I can neither resolve the error nor find a solution anywhere. Please help!

The error is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-219-6352041ff223> in <module>()
      1 gps_d=gps_data.withColumn("dist", dist(
      2     "longitude", "latitude",
----> 3     lag("longitude", 1).over(my_window), lag("latitude", 1).over(my_window)
      4 ).alias("dist"))

<ipython-input-218-2f781ab3b2fb> in dist(lon2, lat2, lon1, lat1)
      9 
     10         a=sin(delta_phi/2.0)**2+cos(phi_1)*cos(phi_2)*sin(delta_lambda/2.0)**2
---> 11         c=2*atan2(sqrt(a),sqrt(1-a))
     12         return c * 6372.8

TypeError: a float is required

PS: When loading the CSV into the DataFrame, I checked that none of the columns contain null values.

1 Answer:

Answer 0 (score: 1)

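I could not reproduce your error, but the output was all null. So I passed Column objects instead of strings for the first two arguments, and it worked: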

gps_data.withColumn("dist", dist(
    col("longitude"), col("latitude"),
    lag("longitude", 1).over(my_window), lag("latitude", 1).over(my_window)
).alias("dist")).show(truncate=False)

The output is:

+---+-----------------+-----------------+--------+-------------------+---------------------+
|id |latitude         |longitude        |track_id|time               |dist                 |
+---+-----------------+-----------------+--------+-------------------+---------------------+
|1  |-10.9393413858164|-37.0627421097422|1       |2014-09-13 07:24:32|null                 |
|2  |-10.939341385769 |-37.0627421097809|1       |2014-09-13 07:24:37|6.756720378438061E-9 |
|3  |-10.9393239478718|-37.0627645137212|1       |2014-09-13 07:24:42|0.0031221549946508337|
|4  |-10.9392105616561|-37.0628430455445|1       |2014-09-13 07:24:47|0.01525123103019258  |
+---+-----------------+-----------------+--------+-------------------+---------------------+
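
As a sanity check, the distances above can be reproduced without Spark. The sketch below applies the same formula to plain Python floats using the standard math module. The function name haversine_km is my own, and 6372.8 km is the Earth radius taken from the question's code. The coordinates are rows 2 and 3 of the table.

```python
import math

def haversine_km(lon2, lat2, lon1, lat1, radius_km=6372.8):
    """Great-circle distance in km between two points given in degrees."""
    phi_1 = math.radians(lat1)
    phi_2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (math.sin(delta_phi / 2.0) ** 2
         + math.cos(phi_1) * math.cos(phi_2) * math.sin(delta_lambda / 2.0) ** 2)
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return c * radius_km

# Rows 2 and 3 of the sample data (values copied from the table above).
d = haversine_km(-37.0627645137212, -10.9393239478718,
                 -37.0627421097809, -10.939341385769)
print(d)  # roughly 0.0031 km, i.e. about 3 metres
```

The result should agree with the dist value in row 3 up to floating-point rounding, which suggests the Column-based version computes what was intended.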

I hope this helps.
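
A note on the original TypeError, which could not be reproduced here. One plausible cause, and this is an assumption because the question does not show its imports, is that sqrt or atan2 resolved to the functions in Python's math module rather than those in pyspark.sql.functions (for example after `from math import *`). The math functions accept only values convertible to float, so they would reject a Column expression. The sketch below imitates this without Spark; FakeColumn is a hypothetical stand-in, not a PySpark class.

```python
import math

class FakeColumn:
    """Hypothetical stand-in for a Spark Column: it cannot be converted to float."""

def try_sqrt(value):
    """Return the exception type name if math.sqrt rejects the value, else None."""
    try:
        math.sqrt(value)
    except TypeError as exc:
        # The exact message wording depends on the Python version.
        return type(exc).__name__
    return None

print(try_sqrt(FakeColumn()))  # TypeError
print(try_sqrt(2.0))           # None
```

If this is indeed the cause, importing the Spark functions under a namespace (`from pyspark.sql import functions as F`) and calling F.sqrt, F.atan2 and so on would avoid the name clash.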