选择对象

时间:2013-11-02 00:35:33

标签: python

我创建了一个Atom对象,如下所示:

class Atom(object):
    def __init__(self, symbol, x, y, z)
        self.symbol = symbol
        self.position = (x, y, z)

Selection类,其中包含按某些条件选择的原子:

class Selection(object):
    def __init__(self, a_system, atom_list=[]):
        for atom in a_system:
            atom_list.append(atom)
        self.atom_list = atom_list

    def by_symbol(self, symbol):
        r_list = []
        for atom in self.atom_list:
            if atom.symbol is symbol:
                r_list.append(atom)
        self.atom_list = r_list

    def by_zrange(self, zmin, zmax):
        r_list = []
        for atom in self.atom_list:
            pos = atom.position[2]
            if pos > zmin and pos < zmax:
                r_list.append(atom)
        self.atom_list = r_list

所以你可以看到我可以说:

# my_system is a list of atoms objects
group = Selection(my_system)

然后说:

group.by_symbol('H')

我将在对象group中拥有所有氢原子。如果我这样做:

group.by_zrange(1, 2)

我将在对象group中拥有z坐标在1和2之间的所有氢原子。

我有其他选择标准,但一般来说它们具有相同的结构,要知道:

r_list = []
for atom in self.atom_list:
    # Some criteria here
        r_list.append(atom)
self.atom_list = r_list

所以问题是:为了避免为每个选择标准编写上述结构,我能做些什么吗?

如果你知道有一种更简单的方法可以实现我的目的,我会很高兴听到它。

2 个答案:

答案 0 :(得分:1)

您可以使用内置的filter()函数,它会自动为您执行循环,并且可以说更优雅:

def by_symbol(self, symbol):
    res = filter(lambda atom: atom.symbol == symbol, self.atom_list)
    self.atom_list.extend(res)

如果您需要更复杂的过滤,则可能需要编写嵌套函数并传递 而不是lambda。它应该是单参数函数,并在正确的结果上返回True

答案 1 :(得分:1)

以下是如何使用内置filter()函数的工作示例。

下面的代码还包括对您的课程的一些其他增强功能以​​及对该想法的一些修饰。请特别注意by_symbol()by_zrange()方法以return self结尾,这样可以更轻松地打印结果并将它们链接在一起,如示例用法所示。

from collections import namedtuple

Point = namedtuple('Point', 'x, y, z')

class Atom(object):
    def __init__(self, symbol, x, y, z):
        self.symbol = symbol
        self.position = Point(x, y, z)

    def __repr__(self):
        return '{name}({sym!r}, {pos.x}, {pos.y}, {pos.z})'.format(
            name=self.__class__.__name__, sym=self.symbol, pos=self.position)

class Selection(object):
    def __init__(self, a_system, atom_list=None):
        if atom_list is None:
            atom_list = []
        for atom in a_system:
            atom_list.append(atom)
        self.atom_list = atom_list

    def __repr__(self):
        return '{name}({atoms})'.format(
            name=self.__class__.__name__, atoms=self.atom_list)

    def _filter(self, func):
        return filter(func, self.atom_list)

    def by_symbol(self, symbol):
        self.atom_list = self._filter(lambda a: a.symbol == symbol)
        return self

    def by_zrange(self, zmin, zmax):
        def zrange(a):
            return zmin <= a.position.z <= zmax
        self.atom_list = self._filter(zrange)
        return self

用法示例:

my_system = [Atom('H', 0, 1, 2),
             Atom('N', 3, 4, 5),
             Atom('C', 6, 7, 8),
             Atom('H', 9, 10, 11),]

group = Selection(my_system)
print group
print group.by_symbol('H')
print group.by_zrange(1, 2)
print
group = Selection(my_system)
print group.by_symbol('H').by_zrange(1, 2)

输出:

Selection([Atom('H', 0, 1, 2), Atom('N', 3, 4, 5), Atom('C', 6, 7, 8),
           Atom('H', 9, 10, 11)])
Selection([Atom('H', 0, 1, 2), Atom('H', 9, 10, 11)])
Selection([Atom('H', 0, 1, 2)])

Selection([Atom('H', 0, 1, 2)])