我是一个使用Python的初学者,我试图在字典中使用搜索功能来搜索具有点的坐标(2)的numpy数组的键。所以,我想要的是:一个字典,其键是numpy数组,其值是整数。然后使用in运算符来使用一些容差度量(numpy.allclose函数)来比较键。我知道numpy数组不是hashables所以我必须覆盖 getitem 和 setitem 函数(基于我在How to properly subclass dict and override __getitem__ & __setitem__中找到的)。但是如何让这些可以在字典中添加它们作为键呢?在这种情况下,如何覆盖in运算符的行为?
感谢您的帮助!
答案 0 :(得分:1)
Numpy数组不可清除但是元组是。因此,如果将数组转换为元组,则可以对数组进行散列。从理论上讲,如果你事先对它进行舍入,你可以利用快速查找,因为你现在有了离散点。但是你会在重新翻译时遇到分辨率问题,因为舍入是用十进制基数完成的,但数字是存储二进制的。可以通过将其转换为缩放的整数来避免这种情况,但这会使所有内容减慢一些。
最后,你只需要编写一个在数组和元组之间来回转换的类,你就可以了。你很高兴。 实现可能如下所示:
import numpy as np
class PointDict(dict):
def __init__(self, precision=5):
super(PointDict, self).__init__()
self._prec = 10**precision
def decode(self, tup):
"""
Turns a tuple that was used as index back into a numpy array.
"""
return np.array(tup, dtype=float)/self._prec
def encode(self, ndarray):
"""
Rounds a numpy array and turns it into a tuple so that it can be used
as index for this dict.
"""
return tuple(int(x) for x in ndarray*self._prec)
def __getitem__(self, item):
return self.decode(super(PointDict, self).__getitem__(self.encode(item)))
def __setitem__(self, item, value):
return super(PointDict, self).__setitem__(self.encode(item), value)
def __contains__(self, item):
return super(PointDict, self).__contains__(self.encode(item))
def update(self, other):
for item, value in other.items():
self[item] = value
def items(self):
for item in self:
yield (item, self[item])
def __iter__(self):
for item in super(PointDict, self).__iter__():
yield self.decode(item)
在查找很多要点时,使用矢量化批量写入/查找的纯粹numpy解决方案可能会更好。但是,该解决方案易于理解和实施。
答案 1 :(得分:0)
使用2元组的浮点数作为键,而不是numpy数组。元组是可以清除的,因为它们是不可变的。
Python词典在后台使用hash-table来快速进行键查找。
编写closeto
函数并不难;
def closeto(a, b, limit=0.1):
x, y = a
p, q = b
return (x-p)**2 + (y-q)**2 < limit**2
这可以用来找到接近的点。但是你必须迭代所有键,因为键查找是准确的。但是如果你在理解中进行这种迭代,它比它for
- 循环要快得多。
测试(在IPython中,使用Python 3):
In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
: def closeto(a, b, limit=0.1):
: x, y = a
: p, q = b
: return (x-p)**2 + (y-q)**2 < limit**2
:--
In [2]: d = {(0.0, 0.0): 12, (1.02, 2.17): 32, (2.0, 4.2): 23}
In [3]: {k: v for k, v in d.items() if closeto(k, (1.0, 2.0), limit=0.5)}
Out[3]: {(1.02, 2.17): 32}
答案 2 :(得分:0)
将数组转换为元组, hashable:
In [18]: a1 = np.array([0.5, 0.5])
In [19]: a2 = np.array([1.0, 1.5])
In [20]: d = {}
In [21]: d[tuple(a1)] = 14
In [22]: d[tuple(a2)] = 15
In [23]: d
Out[23]: {(0.5, 0.5): 14, (1.0, 1.5): 15}
In [24]: a3 = np.array([0.5, 0.5])
In [25]: a3 in d
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-25-07c81d61b999> in <module>()
----> 1 a3 in d
TypeError: unhashable type: 'numpy.ndarray'
In [26]: tuple(a3) in d
Out[26]: True
不幸的是,由于您希望对比较应用容差,因此您没有太多选择,只能迭代查找“关闭”匹配的所有键,无论您是将其实现为函数还是内联。