SQL用于计算几个向量的Tanimoto系数

时间:2010-01-02 17:15:15

标签: python sql collaborative-filtering

我认为通过一个例子来解释我的问题更容易。

我有一个食谱配料表,我已经实现了一个函数来计算成分之间的Tanimoto coefficient。它足够快以计算两个成分之间的系数(需要3个sql查询),但它不能很好地扩展。要计算所有可能成分组合之间的系数,它需要N +(N *(N-1))/ 2个查询或50000个查询仅1k成分。有更快的方法吗?这是我到目前为止所得到的:

class Filtering():
  def __init__(self):
    self._connection=sqlite.connect('database.db')

  def n_recipes(self, ingredient_id):
    cursor = self._connection.cursor()
    cursor.execute('''select count(recipe_id) from recipe_ingredient
        where ingredient_id = ? ''', (ingredient_id, ))
    return cursor.fetchone()[0]

  def n_recipes_intersection(self, ingredient_a, ingredient_b):
    cursor = self._connection.cursor()
    cursor.execute('''select count(drink_id) from recipe_ingredient where
        ingredient_id = ? and recipe_id in (
        select recipe_id from recipe_ingredient
        where ingredient_id = ?) ''', (ingredient_a, ingredient_b))
    return cursor.fetchone()[0]

  def tanimoto(self, ingredient_a, ingredient_b):
    n_a, n_b = map(self.n_recipes, (ingredient_a, ingredient_b))
    n_ab = self.n_recipes_intersection(ingredient_a, ingredient_b)
    return float(n_ab) / (n_a + n_b - n_ab)

4 个答案:

答案 0 :(得分:4)

为什么不简单地将所有食谱提取到内存中然后在内存中计算Tanimoto系数?

它更简单,而且更快,更快。

答案 1 :(得分:3)

如果有人感兴趣,这是我在Alex和S.Lotts的建议之后提出的代码。谢谢你们。

def __init__(self):
    self._connection=sqlite.connect('database.db')
    self._counts = None
    self._intersections = {}

def inc_intersections(self, ingredients):
    ingredients.sort()
    lenght = len(ingredients)
    for i in xrange(1, lenght):
        a = ingredients[i]
        for j in xrange(0, i):
            b = ingredients[j]
            if a not in self._intersections:
                self._intersections[a] = {b: 1}
            elif b not in self._intersections[a]:
                self._intersections[a][b] = 1
            else:
                self._intersections[a][b] += 1


def precompute_tanimoto(self):
    counts = {}
    self._intersections = {}

    cursor = self._connection.cursor()
    cursor.execute('''select recipe_id, ingredient_id
        from recipe_ingredient
        order by recipe_id, ingredient_id''')
    rows = cursor.fetchall()            

    print len(rows)

    last_recipe = None
    for recipe, ingredient in rows:
        if recipe != last_recipe:
            if last_recipe != None:
                self.inc_intersections(ingredients)
            last_recipe = recipe
            ingredients = [ingredient]
        else:
            ingredients.append(ingredient)

        if ingredient not in counts:
            counts[ingredient] = 1
        else:
            counts[ingredient] += 1

    self.inc_intersections(ingredients)

    self._counts = counts

def tanimoto(self, ingredient_a, ingredient_b):
    if self._counts == None:
        self.precompute_tanimoto()

    if ingredient_b > ingredient_a:
        ingredient_b, ingredient_a = ingredient_a, ingredient_b

    n_a, n_b = self._counts[ingredient_a], self._counts[ingredient_b]
    n_ab = self._intersections[ingredient_a][ingredient_b]

    print n_a, n_b, n_ab

    return float(n_ab) / (n_a + n_b - n_ab)

答案 2 :(得分:1)

如果您有1000种成分,1000个查询就足以将每种成分映射到内存中的一组食谱。如果(比方说)一个成分通常是大约100个配方的一部分,每个集合将需要几KB,所以整个字典只需要几MB - 完全没有问题将整个事物保存在内存中(并且仍然不是很严重)如果每种成分的平均配方数量增加了一个数量级,则会出现内存问题。

result = dict()
for ing_id in all_ingredient_ids:
    cursor.execute('''select recipe_id from recipe_ingredient
        where ingredient_id = ?''', (ing_id,))
    result[ing_id] = set(r[0] for r in cursor.fetchall())
return result

在这1000次查询之后,成对Tanimoto系数所需的500,000个计算中的每一个显然都在内存中完成 - 您可以预先计算各个集合的长度的平方作为进一步的加速(并将它们停放在另一个中) dict),每对的关键“A dotproduct B”组件当然是集合交叉点的长度。

答案 3 :(得分:0)

我认为这会减少每对交叉点的2个选择,每对总共4个查询。你无法摆脱O(N ^ 2),因为你正在尝试所有对 - N *(N-1)/ 2只是有多少对。

def n_recipes_intersection(self, ingredient_a, ingredient_b):
  cursor = self._cur
  cursor.execute('''
    select count(recipe_id)
      from recipe_ingredient as A 
        join recipe_ingredient as B using (recipe_id)
      where A.ingredient_id = ? 
        and B.ingredient_id = ?;
      ''', (ingredient_a, ingredient_b))
  return cursor.fetchone()[0]