当属性是列表时,获取类对象的子样本

时间:2014-10-01 22:36:24

标签: python class

我对Python课程有疑问,似乎无法在任何地方找到简单的答案。所以我要说我定义了一个类:

class point(object):
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
    def calc_mag(self):
        self.mag = np.sqrt(self.x*self.x + self.y*self.y + self.z*self.z)

现在我可以通过以下方式轻松创建对象列表:

xs = [1,2,3,4,5]
ys = [2,3,4,5,6]
zs = [3,4,5,6,7]
points = []
for i in range(len(xs)):
    pt = point(xs[i], ys[i], zs[i])
    points.append(pt)

我可以通过

获取这些点对象的子样本
sub_points = [pt for pt in points if pt.x > 1.0]

这有效,但创建部分效率不高,因为我们使用循环而不是向量化。更快的方法就是

points = point(xs, ys, zs)

当我引用属性x时,我得到一个值列表:

in : points.x
out: [1, 2, 3, 4, 5]

我的问题是,对于这个类对象(它本质上是列表的对象而不是对象列表),是否有一种快速获取子样本的方法,就像上面的第一种方法一样?我尝试了一些像

这样的东西
points[points.x > 1]  # Wrong way of doing it

但由于点不是列表,因此无法编入索引并引发错误

当然我也可以应用比较测试然后通过过滤所有其他属性来重新创建对象,但这又是非常低效的并且产生冗余代码。

那么有人知道如何解决这个问题吗?

===================(附加信息)======================== ==

感谢所有回复的人。我想也许我需要在这里澄清一些事情。 上面发布的类不是我的程序中使用的实际类。我发布了一个简化版本,以便更容易和更简单地讨论真正的问题。我使用的实际类更大,更复杂,有40多个属性和方法。话虽如此,我必须保持课堂上的东西以利用漂亮的功能,使用颠簸的数组,pandas数据框或列表推导根本不是一个选项。

此外,性能有点重要,这就是我使用矢量化形式而不是列表推导或循环创建类的原因。我可以用C / C ++编写它仅仅是为了提高性能,但是有一些关于Python的好东西,这使得暂时坚持使用python是有益的。我还可以为最慢的部分编写一个C包装器以提高性能并绕过这个问题,但不知怎的,我觉得我必须在Python中找到解决方案!

3 个答案:

答案 0 :(得分:1)

这在很大程度上取决于应用程序,但像numpy数组这样的数组很适合给定的示例。

import numpy as np

xs = [1,2,3,4,5]
ys = [2,3,4,5,6]
zs = [3,4,5,6,7]
points = np.array([xs, ys, zs]).T  # transpose so rows are points

print(points[points[:, 0] > 1])
# [[2 3 4]
#  [3 4 5]
#  [4 5 6]
#  [5 6 7]]

您甚至可以使用struct arrays来保留标签。

points = np.array(
    [p for p in zip(xs, ys, zs)], 
    dtype= {'names': ['x', 'y', 'z'], 'formats': ['i4']*3}  # i4 for ints
)

print(points[points['x'] > 1])
# [(2, 3, 4) (3, 4, 5) (4, 5, 6) (5, 6, 7)]

如果要保持同一个类访问语法points.x,可以将一个numpy数组包装在一个类中,并添加访问该数组的各个列的属性。请参阅子类ndarray上的documentation

答案 1 :(得分:0)

您要做的是称为布尔索引。 Numpy阵列固有地支持这一点。如果需要标记数组,也可以考虑使用pandas库(想想excel表格数据:带有行标签和列标签的数组)。

您尝试做的事情的问题是您需要自定义对象来支持布尔索引,而python对象不支持此功能。如果您绝对需要自定义行为,则可以子类化numpy数组并重载其控制布尔索引的魔术方法。编辑:您也可以尝试使用其他解决方案指出的记录数组。

http://docs.scipy.org/doc/numpy/user/basics.subclassing.html

这是大熊猫的解决方案。与numpy不同,它支持属性索引。

from pandas import DataFrame
df = DataFrame([[1,2,3], [2,3,4], [3,4,5]], columns=['xs', 'ys', 'zs'])
df

   xs  ys  zs
0   1   2   3
1   2   3   4
2   3   4   5

然后您可以在xs

上编制索引
df['xs'] > 1
0    False
1     True
2     True

Name: xs, dtype: bool
df[df['xs'] > 1]
    xs  ys  zs
1   2   3   4
2   3   4   5

答案 2 :(得分:0)

你提出的问题很少。第一个是理解的创造:

from itertools import izip

class point(object):
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
    def __str__(self):
        return 'P({s.x}, {s.y}, {s.z})'.format(s=self)
    def __repr__(self):
        return str(self)

vectors = izip(xs, ys, zs)
points = [point(*vector) for vector in vectors]
print points

如果您不想使用numpy或pandas容器,您可以使用理解或过滤:

print [p for p in points if p.x < 3]
print filter(lambda p: p.x < 3, points)
filt = lambda p: p.x < 3
print filter(filt, points)

此外,使用模块operatorfunctools,您可以为这些过滤器制作工厂。