Merging lists of class objects

Time: 2019-02-10 18:31:09

Tags: python python-3.x

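I want to compare two lists of class Version(object) objects so that I can merge one into the other while skipping duplicates, but Python seems to treat two Version() objects as different even when their contents are the same.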

I tried giving the objects a custom "comparison" method, following the instructions at https://stackoverflow.com/a/1227325/10881866

This is the class I want to compare:

# Assumed: search is re.search; version_pattern, sign_pattern and platforms
# are defined elsewhere in the asker's code and are not shown here.
from re import search

class Version(object):
    valid_version = False
    version = None
    valid_platform = False
    platform = None
    valid_sign = False
    sign = None

    def __init__(self, version, platform, sign):
        version_match = search(version_pattern, version)
        if version_match:
            self.version = version_match.string
            self.valid_version = True
        else:
            self.version = version
        self.platform = platform
        self.valid_platform = platform in platforms
        sign_match = search(sign_pattern, sign)
        if sign_match:
            self.sign = sign_match.string
            self.valid_sign = True
        else:
            self.sign = sign

    def __str__(self):
        return str(self.__dict__)

    # def __eq__(self, other): return self.sign == other.sign

This is the helper function I use for merging (also found on SO):

def merge_no_duplicates(iterable_1, iterable_2):
    myset = set(iterable_1).union(set(iterable_2))
    return list(myset)

This is the part where I merge the lists:

try:
    remote_versions = getVersionsFromRemote()
    logger.info("Loaded {} remote versions".format(len(remote_versions)))
    versions = merge_no_duplicates(versions, remote_versions)
except:
    logger.error("Can't load remote versions!")
try:
    local_versions = getVersionsFromLocal()
    logger.info("Loaded {} local versions".format(len(local_versions)))
    versions = merge_no_duplicates(versions, local_versions)
except:
    logger.error("Can't load local versions!")
versions = list(filter(None, versions))
logger.info("Got {} versions total.".format(len(versions)))

Expected:

2019-02-10 19:14:38,220|INFO    | Loaded 156 remote versions
2019-02-10 19:14:38,223|INFO    | Loaded 156 local versions
2019-02-10 19:14:38,223|INFO    | Got 156 versions total.

Actual:

2019-02-10 19:14:38,220|INFO    | Loaded 156 remote versions
2019-02-10 19:14:38,223|INFO    | Loaded 156 local versions
2019-02-10 19:14:38,223|INFO    | Got 312 versions total.

1 answer:

Answer 0 (score: 1)

If you want set to remove duplicates, you need to define both the __eq__ and __hash__ methods. Here is a simple example:

class WithoutMethods:
    def __init__(self, a, b):  # Note no class-level attribute declaration
        self.a = a
        self.b = b
    def __repr__(self):
        return "WithoutMethods({0.a}, {0.b})".format(self)

class WithMethods:
    def __init__(self, a, b):
        self.a = a
        self.b = b
    def __repr__(self):
        return "WithMethods({0.a}, {0.b})".format(self)
    def __eq__(self, other):
        if not isinstance(other, WithMethods):
            return NotImplemented
        return (self.a, self.b) == (other.a, other.b)
    def __hash__(self):
        return hash((self.a, self.b))  # There are lots of ways to define hash methods.
                                       # This is the simplest, but may lead to collisions 

print({WithoutMethods(1, 2), WithoutMethods(1, 2)})
# {WithoutMethods(1, 2), WithoutMethods(1, 2)}
print({WithMethods(1, 2), WithMethods(1, 2)})
# {WithMethods(1, 2)}
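
Defining only one of the two methods is not enough. In Python 3, a class that defines __eq__ without also defining __hash__ has its __hash__ set to None, so its instances cannot be placed in a set at all. A minimal sketch of that behavior (the class name EqOnly is hypothetical and used only for illustration):

```python
class EqOnly:
    """Hypothetical class that defines __eq__ but not __hash__."""

    def __init__(self, a):
        self.a = a

    def __eq__(self, other):
        if not isinstance(other, EqOnly):
            return NotImplemented
        return self.a == other.a


# Python 3 disables hashing when __eq__ is defined without __hash__.
print(EqOnly.__hash__)  # None

try:
    {EqOnly(1), EqOnly(1)}
    hashable = True
except TypeError as exc:
    # Building the set fails because the instances are unhashable.
    hashable = False
    print(exc)
```

In the question's code, uncommenting __eq__ on its own would most likely raise this TypeError inside the try block, where the bare except would log it as "Can't load remote versions!" rather than showing the real cause.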

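The reason both methods are needed comes down to how set (and dict) store their values. When you add an object to a set, the set does not compare it against every other object it holds to decide whether it is a duplicate. Instead, it uses the object's hash value to jump to the appropriate location in the set, and then checks whether an equal object is already there. (This is a simplification, because unequal objects sometimes have the same hash value.) Even if you define an __eq__ method, the set will never compare two objects whose hash values differ.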
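
Applied to the question, the same idea might look like the sketch below. SimpleVersion is a hypothetical, simplified stand-in for the asker's Version class (no pattern validation), and treating all three fields as the identity of a version is an assumption; the asker's commented-out __eq__ compares only sign, so the key may need adjusting.

```python
class SimpleVersion:
    """Hypothetical, simplified stand-in for the question's Version class."""

    def __init__(self, version, platform, sign):
        self.version = version
        self.platform = platform
        self.sign = sign

    def _key(self):
        # Assumption: two versions are duplicates when all three fields match.
        return (self.version, self.platform, self.sign)

    def __eq__(self, other):
        if not isinstance(other, SimpleVersion):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Equal objects must hash equally, so hash the same key used by __eq__.
        return hash(self._key())


def merge_no_duplicates(iterable_1, iterable_2):
    # Same helper as in the question.
    return list(set(iterable_1).union(set(iterable_2)))


remote = [SimpleVersion("1.0", "linux", "abc"), SimpleVersion("1.1", "linux", "def")]
local = [SimpleVersion("1.0", "linux", "abc"), SimpleVersion("1.1", "linux", "def")]
merged = merge_no_duplicates(remote, local)
print(len(merged))  # 2
```

Two caveats: a set does not preserve the order of the input lists, and the fields used in the hash should not be changed while an object is stored in a set, otherwise lookups for it can fail.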