I think it's easier to explain my problem with an example.
I have a table of recipe ingredients and I've implemented a function that computes the Tanimoto coefficient between ingredients. It is fast enough for computing the coefficient between two ingredients (it takes 3 SQL queries), but it doesn't scale well: computing the coefficients for all possible ingredient pairs requires N + N*(N-1)/2 queries, roughly 500,000 queries for just 1k ingredients. Is there a faster way to do this? Here is what I have so far:
import sqlite3 as sqlite  # assumed binding; the original relies on an existing `sqlite` alias

class Filtering():
    def __init__(self):
        self._connection = sqlite.connect('database.db')

    def n_recipes(self, ingredient_id):
        # Number of recipes that contain the given ingredient.
        cursor = self._connection.cursor()
        cursor.execute('''select count(recipe_id) from recipe_ingredient
                          where ingredient_id = ?''', (ingredient_id,))
        return cursor.fetchone()[0]

    def n_recipes_intersection(self, ingredient_a, ingredient_b):
        # Number of recipes that contain both ingredients.
        cursor = self._connection.cursor()
        cursor.execute('''select count(recipe_id) from recipe_ingredient
                          where ingredient_id = ? and recipe_id in (
                              select recipe_id from recipe_ingredient
                              where ingredient_id = ?)''', (ingredient_a, ingredient_b))
        return cursor.fetchone()[0]

    def tanimoto(self, ingredient_a, ingredient_b):
        n_a, n_b = map(self.n_recipes, (ingredient_a, ingredient_b))
        n_ab = self.n_recipes_intersection(ingredient_a, ingredient_b)
        return float(n_ab) / (n_a + n_b - n_ab)
Answer 0 (score: 4)
Why not simply fetch all the recipes into memory and then compute the Tanimoto coefficients in memory?
It is simpler, and it is much, much faster.
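A minimal sketch of that idea, assuming the recipe_ingredient table and column names from the question and the standard sqlite3 module; the helper names here are only illustrative:

import sqlite3
from collections import defaultdict

def load_ingredient_sets(db_path='database.db'):
    # One query pulls every (recipe, ingredient) pair; the mapping
    # ingredient_id -> set of recipe ids is then built entirely in memory.
    connection = sqlite3.connect(db_path)
    cursor = connection.cursor()
    cursor.execute('select recipe_id, ingredient_id from recipe_ingredient')
    sets = defaultdict(set)
    for recipe_id, ingredient_id in cursor.fetchall():
        sets[ingredient_id].add(recipe_id)
    return sets

def tanimoto_from_sets(sets, ingredient_a, ingredient_b):
    # Tanimoto on sets: |A & B| / (|A| + |B| - |A & B|)
    a, b = sets[ingredient_a], sets[ingredient_b]
    n_ab = len(a & b)
    return float(n_ab) / (len(a) + len(b) - n_ab)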
Answer 1 (score: 3)
In case anyone is interested, this is the code I ended up with after Alex's and S.Lott's suggestions. Thanks guys.
# Methods of the Filtering class from the question, reworked so that a single
# query over recipe_ingredient feeds an in-memory precomputation (Python 2).

def __init__(self):
    self._connection = sqlite.connect('database.db')
    self._counts = None          # ingredient_id -> number of recipes it appears in
    self._intersections = {}     # ingredient_id -> {smaller ingredient_id: shared recipe count}

def inc_intersections(self, ingredients):
    # Count one shared recipe for every pair of ingredients in this recipe.
    ingredients.sort()
    length = len(ingredients)
    for i in xrange(1, length):
        a = ingredients[i]
        for j in xrange(0, i):
            b = ingredients[j]
            if a not in self._intersections:
                self._intersections[a] = {b: 1}
            elif b not in self._intersections[a]:
                self._intersections[a][b] = 1
            else:
                self._intersections[a][b] += 1

def precompute_tanimoto(self):
    # One pass over the whole table, grouped by recipe, fills both the
    # per-ingredient counts and the pairwise intersection counts.
    counts = {}
    self._intersections = {}
    cursor = self._connection.cursor()
    cursor.execute('''select recipe_id, ingredient_id
                      from recipe_ingredient
                      order by recipe_id, ingredient_id''')
    rows = cursor.fetchall()
    print len(rows)
    last_recipe = None
    for recipe, ingredient in rows:
        if recipe != last_recipe:
            if last_recipe is not None:
                self.inc_intersections(ingredients)
            last_recipe = recipe
            ingredients = [ingredient]
        else:
            ingredients.append(ingredient)
        if ingredient not in counts:
            counts[ingredient] = 1
        else:
            counts[ingredient] += 1
    self.inc_intersections(ingredients)
    self._counts = counts

def tanimoto(self, ingredient_a, ingredient_b):
    if self._counts is None:
        self.precompute_tanimoto()
    # Intersections are keyed with the larger ingredient id first.
    if ingredient_b > ingredient_a:
        ingredient_b, ingredient_a = ingredient_a, ingredient_b
    n_a, n_b = self._counts[ingredient_a], self._counts[ingredient_b]
    n_ab = self._intersections[ingredient_a][ingredient_b]
    print n_a, n_b, n_ab
    return float(n_ab) / (n_a + n_b - n_ab)
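Presumably it would be used along these lines (the ingredient ids here are made up for illustration):

f = Filtering()
f.precompute_tanimoto()      # one query, then everything lives in dicts
print(f.tanimoto(3, 7))      # any later pair is answered from the in-memory dicts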
Answer 2 (score: 1)
If you have 1000 ingredients, 1000 queries are enough to map each ingredient to the set of recipes it appears in, held in memory. If (say) an ingredient is typically part of about 100 recipes, each set will take a few KB, so the whole dictionary of sets will only take a few MB: no problem at all keeping the whole thing in memory (and still no serious memory problem even if the average number of recipes per ingredient grows by an order of magnitude).
def recipes_by_ingredient(cursor, all_ingredient_ids):
    # Wrapper name and arguments assumed; the original snippet is a bare loop.
    result = dict()
    for ing_id in all_ingredient_ids:
        cursor.execute('''select recipe_id from recipe_ingredient
                          where ingredient_id = ?''', (ing_id,))
        result[ing_id] = set(r[0] for r in cursor.fetchall())
    return result
After those 1000 queries, every one of the ~500,000 computations needed for the pairwise Tanimoto coefficients is obviously done in memory. You can precompute the squared lengths of the individual sets as a further speed-up (and park them in another dict); the key "A dot B" component for each pair is, of course, the length of the intersection of the two sets.
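A rough sketch of that pairwise stage, assuming a dict mapping each ingredient_id to its set of recipe ids (as built by the loop above); itertools.combinations enumerates the N*(N-1)/2 pairs:

from itertools import combinations

def all_pairwise_tanimoto(recipe_sets):
    # recipe_sets: ingredient_id -> set of recipe ids
    lengths = dict((ing, len(s)) for ing, s in recipe_sets.items())  # precomputed once
    coefficients = {}
    for a, b in combinations(recipe_sets, 2):
        n_ab = len(recipe_sets[a] & recipe_sets[b])   # the "A dot B" component
        if n_ab:   # pairs sharing no recipe have coefficient 0 and can be skipped
            coefficients[(a, b)] = float(n_ab) / (lengths[a] + lengths[b] - n_ab)
    return coefficients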
Answer 3 (score: 0)
I think this saves you 2 of the selects per pairwise intersection, out of the 4 queries per pair in total. You can't get away from O(N^2), though, since you're trying all pairs: N*(N-1)/2 is simply how many pairs there are.
def n_recipes_intersection(self, ingredient_a, ingredient_b):
    cursor = self._cur   # assumes a cursor kept on the instance
    cursor.execute('''
        select count(recipe_id)
        from recipe_ingredient as A
        join recipe_ingredient as B using (recipe_id)
        where A.ingredient_id = ?
          and B.ingredient_id = ?;
        ''', (ingredient_a, ingredient_b))
    return cursor.fetchone()[0]