我创建了一个Atom对象,如下所示:
class Atom(object):
def __init__(self, symbol, x, y, z)
self.symbol = symbol
self.position = (x, y, z)
和Selection
类,其中包含按某些条件选择的原子:
class Selection(object):
def __init__(self, a_system, atom_list=[]):
for atom in a_system:
atom_list.append(atom)
self.atom_list = atom_list
def by_symbol(self, symbol):
r_list = []
for atom in self.atom_list:
if atom.symbol is symbol:
r_list.append(atom)
self.atom_list = r_list
def by_zrange(self, zmin, zmax):
r_list = []
for atom in self.atom_list:
pos = atom.position[2]
if pos > zmin and pos < zmax:
r_list.append(atom)
self.atom_list = r_list
所以你可以看到我可以说:
# my_system is a list of atoms objects
group = Selection(my_system)
然后说:
group.by_symbol('H')
我将在对象group
中拥有所有氢原子。如果我这样做:
group.by_zrange(1, 2)
我将在对象group
中拥有z坐标在1和2之间的所有氢原子。
我有其他选择标准,但一般来说它们具有相同的结构,要知道:
r_list = []
for atom in self.atom_list:
# Some criteria here
r_list.append(atom)
self.atom_list = r_list
所以问题是:为了避免为每个选择标准编写上述结构,我能做些什么吗?
如果你知道有一种更简单的方法可以实现我的目的,我会很高兴听到它。
答案 0 :(得分:1)
您可以使用内置的filter()函数,它会自动为您执行循环,并且可以说更优雅:
def by_symbol(self, symbol):
res = filter(lambda atom: atom.symbol == symbol, self.atom_list)
self.atom_list.extend(res)
如果您需要更复杂的过滤,则可能需要编写嵌套函数并传递
而不是lambda
。它应该是单参数函数,并在正确的结果上返回True
。
答案 1 :(得分:1)
以下是如何使用内置filter()函数的工作示例。
下面的代码还包括对您的课程的一些其他增强功能以及对该想法的一些修饰。请特别注意by_symbol()
和by_zrange()
方法以return self
结尾,这样可以更轻松地打印结果并将它们链接在一起,如示例用法所示。
from collections import namedtuple
Point = namedtuple('Point', 'x, y, z')
class Atom(object):
def __init__(self, symbol, x, y, z):
self.symbol = symbol
self.position = Point(x, y, z)
def __repr__(self):
return '{name}({sym!r}, {pos.x}, {pos.y}, {pos.z})'.format(
name=self.__class__.__name__, sym=self.symbol, pos=self.position)
class Selection(object):
def __init__(self, a_system, atom_list=None):
if atom_list is None:
atom_list = []
for atom in a_system:
atom_list.append(atom)
self.atom_list = atom_list
def __repr__(self):
return '{name}({atoms})'.format(
name=self.__class__.__name__, atoms=self.atom_list)
def _filter(self, func):
return filter(func, self.atom_list)
def by_symbol(self, symbol):
self.atom_list = self._filter(lambda a: a.symbol == symbol)
return self
def by_zrange(self, zmin, zmax):
def zrange(a):
return zmin <= a.position.z <= zmax
self.atom_list = self._filter(zrange)
return self
用法示例:
my_system = [Atom('H', 0, 1, 2),
Atom('N', 3, 4, 5),
Atom('C', 6, 7, 8),
Atom('H', 9, 10, 11),]
group = Selection(my_system)
print group
print group.by_symbol('H')
print group.by_zrange(1, 2)
print
group = Selection(my_system)
print group.by_symbol('H').by_zrange(1, 2)
输出:
Selection([Atom('H', 0, 1, 2), Atom('N', 3, 4, 5), Atom('C', 6, 7, 8),
Atom('H', 9, 10, 11)])
Selection([Atom('H', 0, 1, 2), Atom('H', 9, 10, 11)])
Selection([Atom('H', 0, 1, 2)])
Selection([Atom('H', 0, 1, 2)])