在对象属性的上下文中,numpy数组无法与==
相比(使用np.array_equal
的语义),我遇到问题。
考虑以下示例:
>>> import numpy as np
>>> class A:
... def __init__(self, a):
... self.a = a
... def __eq__(self, other):
... return self.__dict__ == other.__dict__
...
>>> x = A(a=[1, np.array([1, 2])])
>>> y = A(a=[1, np.array([1, 2])])
>>> x == y
Traceback (most recent call last):
File "<ipython-input-33-9cfbd892cdaa>", line 1, in <module>
x == y
File "<ipython-input-30-790950997d4f>", line 5, in __eq__
return self.__dict__ == other.__dict__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
(忽略__eq__
并不完美,它至少应该检查other
的类型,但这是为了简洁起见)
我将如何实现一个__eq__
函数来处理嵌套在对象属性深处的numpy数组(假设其他所有内容(例如本示例中的列表)与==
比较好)? numpy数组可能出现在列表,元组或字典中的任意深度嵌套级别。
我尝试提出递归eq
函数的“手动”实现,该函数将==
应用于所有属性,并在遇到numpy数组时使用np.array_equal
,但这比预期的要复杂。
有人有合适的功能或简单的解决方法吗?
答案 0 :(得分:0)
如果可以选择更改对象x
和y
,则可以根据自己的喜好覆盖__eq__
的{{1}}方法。
np.ndarray
此结果显示在class eqarr(np.ndarray):
def __eq__(self, other):
return np.array_equal(self, other)
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return self.__dict__ == other.__dict__
x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y
中。
如果这不可能,那么我目前能想到的唯一解决方案是实际实现递归的相等性检查功能。我的尝试如下:
True
有了您的示例以及我提出的所有示例,它们都能奏效。只要嵌套对象具有def eq(a, b):
if not (hasattr(a, '__iter__') or type(a) == str):
return a == b
try:
if not len(a) == len(b):
return False
if type(a) == np.ndarray:
return np.array_equal(a, b)
if isinstance(a, dict):
return all(eq(v, b[k]) for k, v in a.items())
else:
return all(eq(aa, bb) for aa, bb in zip(a, b))
except (TypeError, KeyError):
return False
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return eq(self.__dict__, other.__dict__)
和__iter__
属性,该解决方案就应该适用。
我希望我已经解决了所有可能的错误,但是您可能需要稍微调整一下代码以使其绝对安全。
如果找到反例,请提供它作为注释。我确定代码可以相应地进行调整。
__len__
的表现可能不佳,但是我不知道这是否是您最关心的问题。
如果numpy数组在您的层次结构中很少使用(并且通常接近顶部),则始终可以首先尝试进行常规比较。看起来可能如下所示:
eq