numpy过滤器在每个元素上的使用条件

时间:2018-09-18 07:18:53

标签: python numpy

我有一个过滤器表达式,如下所示:

feasible_agents = filter(lambda agent: agent >= cost[task, agent], agents)

其中agents是python列表。

现在,为了加快速度,我正在尝试使用numpy来实现。

使用numpy相当于什么?

我知道这可行:

threshold = 5.0
feasible_agents = np_agents[np_agents > threshold]

其中np_agentsagents的numpy等效项。

但是,我希望阈值是numpy数组中每个元素的函数。

2 个答案:

答案 0 :(得分:1)

由于您没有提供示例数据,请使用玩具数据:

# Cost of agents represented by indices of cost, we have agents 0, 1, 2, 3
cost = np.array([4,5,6,2])
# Agents to consider 
np_agents = np.array([0,1,3])
# threshold for each agent. Calculate different thresholds for different agents. Use array of indexes np_agents into cost array.
thresholds = cost[np_agents] # np.array([4,5,2])
feasible_agents = np_agents[np_agents > thresholds] # np.array([3])

答案 1 :(得分:0)

您可以使用numpy.extract

SELECT * FROM
(SELECT 310 AS code UNION SELECT 350 UNION SELECT 301 UNION SELECT 302) AS t1
WHERE NOT EXISTS(SELECT 1 FROM customer
             WHERE code = t1.code)

numpy.where

>>> nparr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
>>> nparreven = np.extract(nparr % 2 == 0, nparr)