I have a DataFrame routes with the following structure:
id nodes traveltimes
0 id-1 [node-A, node-B] [6.0]
1 id-2 [node-A, node-C, node-D, node-E] [4.0, 80.0, 38.0]
2 id-3 [node-B, node-D] [90.0]
3 id-4 [node-A] []
4 id-5 [node-A, node-B, node-C, node-D, node-E, node-D] [35.0, 30.0, 110.0, 20.0, 5.0]
.. ... ...
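The lists in the nodes column are the nodes of a graph, and the values in the traveltimes column are the travel times between consecutive nodes. Each row corresponds to one route in the graph.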
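I want to split each route wherever a travel time reaches or exceeds a threshold, dropping the segment that triggered the split. For example, with a threshold of 70, I want to obtain the following result: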
id route_id nodes traveltimes
0 id-1 0 [node-A, node-B] [6.0]
1 id-2 0 [node-A, node-C] [4.0]
2 id-2 1 [node-D, node-E] [38.0]
3 id-3 0 [node-B] []
4 id-3 1 [node-D] []
5 id-4 0 [node-A] []
6 id-5 0 [node-A, node-B, node-C] [35.0, 30.0]
7 id-5 1 [node-D, node-E, node-D] [20.0, 5.0]
.. ... ...
I wrote the following code, but it is very inefficient.
I have a function that splits a single route:
def split_routes(row):
    newrow = row.copy()
    threshold = 70
    nodes = newrow['nodes']
    traveltimes = newrow['traveltimes']
    rows = []
    route_id = 0
    route_nodes = []
    route_traveltimes = []
    route_nodes.append(nodes[0])
    for i in range(1, len(nodes)):
        if traveltimes[i - 1] < threshold:
            route_traveltimes.append(traveltimes[i - 1])
            route_nodes.append(nodes[i])
        else:
            # Route route_id completed, starting a new one
            newrow['route_id'] = route_id
            newrow['nodes'] = route_nodes
            newrow['traveltimes'] = route_traveltimes
            rows.append(newrow)
            newrow = row.copy()
            route_nodes = []
            route_traveltimes = []
            route_id += 1
            route_nodes.append(nodes[i])
    # Route route_id completed
    newrow['route_id'] = route_id
    newrow['nodes'] = route_nodes
    newrow['traveltimes'] = route_traveltimes
    rows.append(newrow)
    df = pd.DataFrame(rows)
    return df
And this is how I use it:
splitted_routes_array = []
for index, row in routes.iterrows(): # Inefficient loop
    splitted_routes_array.append(split_routes(row))
splitted_routes = pd.concat(splitted_routes_array).reset_index(drop=True)
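I suspect there is a more efficient way to do this without iterating over the rows myself, but I don't know how to use apply to return multiple rows and multiple columns at the same time.
Can anyone give me some hints?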
Answer 0 (score: 0)
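To explode multiple columns in pandas, the only prerequisite is that the lists in every column being exploded contain the same number of elements. This can be achieved by pairing consecutive nodes, so that nodes has one entry per travel time: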
def get_nodes(x):
    # A route with fewer than two nodes has no node pairs
    if len(x) < 2:
        return []
    return [[x[i], x[i + 1]] for i in range(len(x) - 1)]

df['nodes'] = df['nodes'].apply(get_nodes)
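As a quick illustration of why this helps, here is a small sketch that repeats the helper so it runs on its own and applies it to the id-2 route from the question. The variable names nodes, traveltimes and pairs are only for this example.

```python
def get_nodes(x):
    # Pair each node with its successor; fewer than two nodes means no pairs.
    if len(x) < 2:
        return []
    return [[x[i], x[i + 1]] for i in range(len(x) - 1)]


nodes = ['node-A', 'node-C', 'node-D', 'node-E']
traveltimes = [4.0, 80.0, 38.0]

pairs = get_nodes(nodes)

# After pairing, each travel time lines up with exactly one node pair,
# which is the equal-length condition needed before exploding both columns.
assert len(pairs) == len(traveltimes)
```

A single-node route such as id-4 yields an empty list, which matches its empty traveltimes list.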
After this, the data can be flattened with:
df = df.set_index('id').apply(lambda x: x.apply(pd.Series).stack()).reset_index().rename(columns={'level_1':'route_id'})
To find all routes with a travel time greater than 70.0, we simply use:
df[df['traveltimes']>70]
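Note that this filter selects rows; it does not by itself cut a route into sub-routes at the threshold. A minimal sketch of the split, assuming the same rule as the question's code (a travel time below the threshold extends the current sub-route, anything else starts a new one), could look like the following. The names split_route and split_all are hypothetical helpers, not pandas functions, and they work on plain lists so that only one DataFrame is built at the end.

```python
def split_route(nodes, traveltimes, threshold=70):
    """Split one route wherever a travel time is >= threshold.

    Returns a list of (route_id, nodes, traveltimes) tuples.
    """
    if not nodes:
        return []
    segments = []
    current_nodes = [nodes[0]]
    current_times = []
    for node, time in zip(nodes[1:], traveltimes):
        if time < threshold:
            # Short hop: stay on the current sub-route.
            current_nodes.append(node)
            current_times.append(time)
        else:
            # Long hop: close the current sub-route and start a new one.
            segments.append((len(segments), current_nodes, current_times))
            current_nodes, current_times = [node], []
    segments.append((len(segments), current_nodes, current_times))
    return segments


def split_all(records, threshold=70):
    """Flatten (id, nodes, traveltimes) records into one row per sub-route."""
    return [
        (rid, route_id, seg_nodes, seg_times)
        for rid, nodes, times in records
        for route_id, seg_nodes, seg_times in split_route(nodes, times, threshold)
    ]


rows = split_all([
    ('id-2', ['node-A', 'node-C', 'node-D', 'node-E'], [4.0, 80.0, 38.0]),
    ('id-3', ['node-B', 'node-D'], [90.0]),
])
```

The resulting tuples can then be passed once to pd.DataFrame(rows, columns=['id', 'route_id', 'nodes', 'traveltimes']), which avoids creating and concatenating one small DataFrame per route. The input records could come from routes.itertuples(index=False), assuming the columns are in the order id, nodes, traveltimes.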