如何优化BK-Tree

时间:2012-04-07 05:39:51

标签: python algorithm cython bk-tree

我在Cython中实现了一个BK-Tree。

对于一百万个条目,搜索耗时太长了:大约需要 30 秒 :(

这是我的Cython代码:

# -*- coding: UTF-8 -*-

from itertools import imap
from PIL import Image

# Compile-time constant: number of items a single BKTree may hold before
# BKTreePool.addItem rotates to a fresh tree.
DEF MAX_TREE_POOL = 10000

# C helpers implemented in distances.h.
cdef extern from "distances.h":
    # Counts the positions at which the two buffers differ.
    int hamming_distance(char *a, char *b)
    # Expected hash length; Item.__init__ asserts len(hash) == HASH_BITS.
    enum: HASH_BITS


cdef findInTree(Node parent, Item item, int threshold):
    """Collect every item within ``threshold`` of ``item`` below ``parent``.

    Returns a list of ``(distance, Item)`` tuples in pre-order: a node is
    reported before its children, and children are visited in ascending
    order of their edge distance. An empty list is returned when ``parent``
    is None.
    """
    cdef int dist
    cdef int key
    cdef int lowest

    cdef Node node
    cdef Node kid

    cdef list matches = []
    cdef list pending

    if not parent:
        return matches

    # Explicit stack instead of recursion; no per-node result lists to merge.
    pending = [parent]

    while pending:
        node = pending.pop()
        dist = hamming_distance(item.hash, node.item.hash)

        if dist <= threshold:
            matches.append((dist, node.item))

        # Only children whose edge distance lies in
        # [dist - threshold, dist + threshold] are looked at.
        # Push them in descending order so they pop in ascending order.
        lowest = max(0, dist - threshold)
        for key in xrange(dist + threshold, lowest - 1, -1):
            kid = node.childrens.get(key)
            if kid:
                pending.append(kid)

    return matches


cdef class Item:
    """A hash value tagged with a numeric identifier.

    Attributes:
        id: identifier; asserted to be greater than zero.
        hash: the hash value; asserted to have length HASH_BITS.
    """

    cdef public unsigned int id
    cdef public object hash

    def __init__(Item self, unsigned int id, object hash):
        """Store ``id`` and ``hash`` after checking them.

        Raises:
            AssertionError: if ``id`` is zero or ``len(hash)`` differs from
                HASH_BITS.
        """
        assert id > 0 and len(hash) == HASH_BITS

        self.id = id
        self.hash = hash

    def __str__(Item self):
        return '<Item {0}>'.format(self.id)

    def __repr__(Item self):
        # ``{1:x}`` renders the object id in hexadecimal so that it really
        # matches the literal "0x" prefix (it used to be printed in decimal).
        return '<Item #{0} object at 0x{1:x}>'.format(self.id, id(self))


cdef class Node:
    """One BK-tree node: an Item plus its children keyed by distance.

    Attributes:
        item: the Item stored at this node.
        childrens: dict mapping an integer distance to the child Node that
            sits at that distance from ``item``.
    """

    cdef readonly Item item
    cdef readonly dict childrens

    def __cinit__(Node self, Item item):
        self.item = item
        self.childrens = {}

    # ``self`` is a Node here (it was mistakenly declared as Item).
    def __repr__(Node self):
        # ``{0:x}`` renders the object id in hexadecimal so that it matches
        # the literal "0x" prefix.
        return '<Node object at 0x{0:x} item {1} childrens {2}>'.format(id(self), repr(self.item), repr(self.childrens))


cdef class BKTree:
    """A BK-tree of Items indexed by hamming_distance between their hashes.

    Attributes:
        tree: root Node, or None while the tree is empty.
        count: number of items added so far.
    """

    cdef readonly Node tree
    cdef readonly unsigned int count

    def __cinit__(BKTree self):
        self.count = 0

    def addItem(BKTree self, Item item):
        """Insert ``item`` and return the new item count.

        The first item becomes the root. Every later item walks down from
        the root, at each node following the child stored under its distance
        to that node, and is attached where no such child exists yet.
        """
        cdef int d

        cdef object a

        cdef Node n
        cdef Node c

        if not self.tree:
            self.tree = Node(item)
        else:
            c = self.tree
            a = item.hash

            while True:
                d = hamming_distance(a, c.item.hash)
                n = c.childrens.get(d)

                if n is None:
                    # Free slot at this distance: attach the new node here.
                    c.childrens[d] = Node(item)
                    break

                # Reuse the child already fetched instead of a second lookup.
                c = n

        self.count += 1

        return self.count

    def query(BKTree self, Item item, int threshold):
        """Return the ``(distance, Item)`` tuples findInTree yields for this tree."""
        return findInTree(self.tree, item, threshold)


cdef class BKTreePool:
    """A list of BKTree instances that together act as one index.

    New items go into the most recently created tree; once that tree holds
    MAX_TREE_POOL items a new tree is started. Queries scan every tree.

    Attributes:
        pool: all trees created so far.
        count: number of items added to the pool.
        tree: the tree currently receiving new items (last element of pool).
    """

    cdef list pool
    cdef readonly unsigned int count
    cdef BKTree tree

    def __cinit__(BKTreePool self):
        self.pool = []
        self.rotate()

    def addItem(BKTreePool self, Item item):
        """Insert ``item`` into the current tree and return the pool's item count.

        Errors raised while inserting now propagate to the caller. The
        previous ``return`` inside ``finally`` discarded them silently, so a
        failed insert looked like a successful call.
        """
        if self.tree.count >= MAX_TREE_POOL:
            self.rotate()

        self.tree.addItem(item)
        self.count += 1

        return self.count

    def query(BKTreePool self, Item item, int threshold):
        """Return the concatenated query results of every tree in the pool."""
        cdef BKTree tree
        cdef list results

        results = []

        for tree in self.pool:
            results.extend(tree.query(item, threshold))

        return results


    cdef rotate(BKTreePool self):
        """Append a fresh BKTree to the pool and make it the current tree."""
        self.pool.append(BKTree())
        self.tree = self.pool[-1]

distances.h

#ifndef DISTANCES_H

  #define DISTANCES_H 1
  /* Parenthesized so the macro expands safely inside larger expressions. */
  #define HASH_BITS (16 * 16)

  static int hamming_distance(const char *a, const char *b);
  // static int default_distance(char *a, char *b);

  /*
   * Count the positions at which a and b differ.
   *
   * Both buffers must hold at least HASH_BITS elements; exactly the first
   * HASH_BITS elements of each are compared.
   */
  static int hamming_distance(const char *a, const char *b) {
      int distance = 0;
      int i;

      /* Strictly less than HASH_BITS: the former "<=" bound also compared
         element HASH_BITS, one past the end of the hash. */
      for (i = 0; i < HASH_BITS; i++) {
          if (a[i] != b[i]) {
              distance++;
          }
      }

      return distance;
  }

#endif

示例:

tree = BKTreePool()
tree.addItem(Item(1, '10' * 256))
tree.addItem(Item(1, '10' * 256))
....

tree.query(Item(1, '10' * 256), 5)

这棵树用于通过256位散列搜索重复图像。

如何优化这个findInTree函数?

1 个答案:

答案 0 :(得分:1)

将“256位散列”用256个二进制位(32字节)表示,而不是用256或512字节表示,可以节省大量内存(从而减少换页和/或缓存失效)。

Python伪代码:

# Lookup table: num_bits_set[v] is the number of 1 bits in the byte value v
# (0, 1, 1, 2, 1, ..., 7, 8).
num_bits_set = tuple(bin(value).count('1') for value in range(256))
assert len(num_bits_set) == 256


def ham_diff(a, b):
    """Return the number of differing bits between two packed hashes.

    ``a`` and ``b`` are strings whose characters each carry 8 bits of the
    hash (character codes 0-255). Characters are compared pairwise; if the
    lengths differ, the extra tail of the longer string is ignored because
    ``zip`` stops at the shorter one.
    """
    # XOR leaves a 1 bit wherever the two characters disagree; the table
    # turns that into a bit count without looping over individual bits.
    return sum(num_bits_set[ord(p) ^ ord(q)] for p, q in zip(a, b))