覆盖字典行为Python3

时间:2015-08-18 11:36:53

标签: python python-3.x numpy

我是一个使用Python的初学者,我试图在字典中使用搜索功能来搜索具有点的坐标(2)的numpy数组的键。所以,我想要的是:一个字典,其键是numpy数组,其值是整数。然后使用in运算符来使用一些容差度量(numpy.allclose函数)来比较键。我知道numpy数组不是hashables所以我必须覆盖 getitem setitem 函数(基于我在How to properly subclass dict and override __getitem__ & __setitem__中找到的)。但是如何让这些可以在字典中添加它们作为键呢?在这种情况下,如何覆盖in运算符的行为?

感谢您的帮助!

3 个答案:

答案 0 :(得分:1)

Numpy数组不可清除但是元组是。因此,如果将数组转换为元组,则可以对数组进行散列。从理论上讲,如果你事先对它进行舍入,你可以利用快速查找,因为你现在有了离散点。但是你会在重新翻译时遇到分辨率问题,因为舍入是用十进制基数完成的,但数字是存储二进制的。可以通过将其转换为缩放的整数来避免这种情况,但这会使所有内容减慢一些。

最后,你只需要编写一个在数组和元组之间来回转换的类,你就可以了。你很高兴。 实现可能如下所示:

import numpy as np

class PointDict(dict):

    def __init__(self, precision=5):
        super(PointDict, self).__init__()
        self._prec = 10**precision

    def decode(self, tup):
        """
        Turns a tuple that was used as index back into a numpy array.
        """
        return np.array(tup, dtype=float)/self._prec

    def encode(self, ndarray):
        """
        Rounds a numpy array and turns it into a tuple so that it can be used
        as index for this dict.
        """
        return tuple(int(x) for x in ndarray*self._prec)

    def __getitem__(self, item):
        return self.decode(super(PointDict, self).__getitem__(self.encode(item)))

    def __setitem__(self, item, value):
        return super(PointDict, self).__setitem__(self.encode(item), value)

    def __contains__(self, item):
        return super(PointDict, self).__contains__(self.encode(item))

    def update(self, other):
        for item, value in other.items():
            self[item] = value

    def items(self):
        for item in self:
            yield (item, self[item])

    def __iter__(self):
        for item in super(PointDict, self).__iter__():
            yield self.decode(item)

在查找很多要点时,使用矢量化批量写入/查找的纯粹numpy解决方案可能会更好。但是,该解决方案易于理解和实施。

答案 1 :(得分:0)

使用2元组的浮点数作为键,而不是numpy数组。元组是可以清除的,因为它们是不可变的。

Python词典在后台使用hash-table来快速进行键查找。

编写closeto函数并不难;

def closeto(a, b, limit=0.1):
    x, y = a
    p, q = b
    return (x-p)**2 + (y-q)**2 < limit**2

这可以用来找到接近的点。但是你必须迭代所有键,因为键查找是准确的。但是如果你在理解中进行这种迭代,它比它for - 循环要快得多。

测试(在IPython中,使用Python 3):

In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
:    def closeto(a, b, limit=0.1):
:        x, y = a
:        p, q = b
:        return (x-p)**2 + (y-q)**2 < limit**2
:--

In [2]: d = {(0.0, 0.0): 12, (1.02, 2.17): 32, (2.0, 4.2): 23}

In [3]: {k: v for k, v in d.items() if closeto(k, (1.0, 2.0), limit=0.5)}
Out[3]: {(1.02, 2.17): 32}

答案 2 :(得分:0)

将数组转换为元组, hashable:

In [18]: a1 = np.array([0.5, 0.5])

In [19]: a2 = np.array([1.0, 1.5])

In [20]: d = {}

In [21]: d[tuple(a1)] = 14

In [22]: d[tuple(a2)] = 15

In [23]: d
Out[23]: {(0.5, 0.5): 14, (1.0, 1.5): 15}

In [24]: a3 = np.array([0.5, 0.5])

In [25]: a3 in d
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-07c81d61b999> in <module>()
----> 1 a3 in d

TypeError: unhashable type: 'numpy.ndarray'

In [26]: tuple(a3) in d
Out[26]: True

不幸的是,由于您希望对比较应用容差,因此您没有太多选择,只能迭代查找“关闭”匹配的所有键,无论您是将其实现为函数还是内联。