编辑Haskell中的距离算法 - 性能调优

时间:2011-04-01 14:50:52

标签: performance haskell

我正在尝试在Haskell中实现levenshtein距离(或编辑距离),但是当字符串长度增加时,它的性能会迅速下降。

我仍然是Haskell的新手,所以如果你能就我如何改进算法给我一些建议会很好。我已经尝试“预先计算”值(inits),但由于它没有改变任何东西,我还原了那个改变。

我知道Hackage上已经有editDistance实现,但是我需要它来处理任意标记列表,而不是字符串。另外,我发现它有点复杂,至少与我的版本相比。

所以,这是代码:

-- standard levenshtein distance between two lists
editDistance      :: Eq a => [a] -> [a] -> Int
editDistance s1 s2 = editDistance' 1 1 1 s1 s2 

-- weighted levenshtein distance
-- ins, sub and del are the costs for the various operations
editDistance'      :: Eq a => Int -> Int -> Int -> [a] -> [a] -> Int
editDistance' _ _ ins s1 [] = ins * length s1 
editDistance' _ _ ins [] s2 = ins * length s2 
editDistance' del sub ins s1 s2  
    | last s1 == last s2 = editDistance' del sub ins (init s1) (init s2)
    | otherwise          = minimum [ editDistance' del sub ins s1 (init s2)        + del -- deletion 
                                   , editDistance' del sub ins (init s1) (init s2) + sub -- substitution
                                   , editDistance' del sub ins (init s1) s2        + ins -- insertion
                                   ]

这似乎是一个正确的实现,至少它提供与此online tool完全相同的结果。

提前感谢您的帮助!如果您需要任何其他信息,请告诉我们。

问候, BZN

6 个答案:

答案 0 :(得分:20)

忽略这是一个糟糕的算法(应该记忆,我到达那一秒)......

使用O(1)基元而非O(n)

一个问题是你对列表使用O(n)的整串调用(haskell列表是单链表)。更好的数据结构将为您提供O(1)操作,我使用Vector

import qualified Data.Vector as V

-- standard levenshtein distance between two lists
editDistance      :: Eq a => [a] -> [a] -> Int
editDistance s1 s2 = editDistance' 1 1 1 (V.fromList s1) (V.fromList s2)

-- weighted levenshtein distance
-- ins, sub and del are the costs for the various operations
editDistance'      :: Eq a => Int -> Int -> Int -> V.Vector a -> V.Vector a -> Int
editDistance' del sub ins s1 s2
  | V.null s2 = ins * V.length s1
  | V.null s1 = ins * V.length s2
  | V.last s1 == V.last s2 = editDistance' del sub ins (V.init s1) (V.init s2)
  | otherwise            = minimum [ editDistance' del sub ins s1 (V.init s2)        + del -- deletion 
                                   , editDistance' del sub ins (V.init s1) (V.init s2) + sub -- substitution
                                   , editDistance' del sub ins (V.init s1) s2        + ins -- insertion
                                   ]

列表的O(n)操作包括init,lengthlast(尽管init至少可以是懒惰的)。所有这些操作都是O(1)使用Vector。

虽然真正的基准测试应该使用Criterion,这是一个快速而肮脏的基准:

str2 = replicate 15 'a' ++ replicate 25 'b'
str1 = replicate 20 'a' ++ replicate 20 'b'
main = print $ editDistance str1 str2

显示矢量版本需要0.09秒而字符串需要1.6秒,因此我们甚至没有查看您的editDistance算法就节省了大约一个数量级。

现在怎样才能记住结果?

更大的问题显然是需要记忆。我把这作为学习monad-memo包的机会 - 我的上帝真棒!对于一个额外的约束(你需要Ord a),你基本上没有努力得到一个memoization。代码:

import qualified Data.Vector as V
import Control.Monad.Memo

-- standard levenshtein distance between two lists
editDistance      :: (Eq a, Ord a) => [a] -> [a] -> Int
editDistance s1 s2 = startEvalMemo $ editDistance' (1, 1, 1, (V.fromList s1), (V.fromList s2))

-- weighted levenshtein distance
-- ins, sub and del are the costs for the various operations
editDistance' :: (MonadMemo (Int, Int, Int, V.Vector a, V.Vector a) Int m, Eq a) => (Int, Int, Int, V.Vector a, V.Vector a) -> m Int
editDistance' (del, sub, ins, s1, s2)
  | V.null s2 = return $ ins * V.length s1
  | V.null s1 = return $ ins * V.length s2
  | V.last s1 == V.last s2 = memo editDistance' (del, sub, ins, (V.init s1), (V.init s2))
  | otherwise = do
        r1 <- memo editDistance' (del, sub, ins, s1, (V.init s2))
        r2 <- memo editDistance' (del, sub, ins, (V.init s1), (V.init s2))
        r3 <- memo editDistance' (del, sub, ins, (V.init s1), s2)
        return $ minimum [ r1 + del -- deletion 
                         , r2 + sub -- substitution
                         , r3 + ins -- insertion
                                   ]

你看到memoization如何需要一个“key”(参见MonadMemo类)?我将所有参数打包成一个丑陋的大元组。它还需要一个“值”,即您生成的Int。然后它只是使用“备忘录”功能即插即用,即可记忆您想要的值。

对于基准测试,我使用了更短但更大距离的字符串:

$ time ./so  # the memoized vector version
12

real    0m0.003s

$ time ./so3  # the non-memoized vector version
12

real    1m33.122s

甚至不考虑运行非memoized字符串版本,我认为它至少需要大约15分钟。至于我,我现在喜欢monad-memo - 感谢Eduard这个包!

编辑:StringVector之间的差异在备忘录版本中没有那么多,但是当距离达到200左右时仍然会增长到2倍,所以仍值得。

编辑:也许我应该解释为什么更大的问题是“显然”记忆结果。好吧,如果你看一下原始算法的核心:

 [ editDistance' ... s1          (V.init s2)  + del 
 , editDistance' ... (V.init s1) (V.init s2) + sub
 , editDistance' ... (V.init s1) s2          + ins]

很明显,调用editDistance' s1 s2会导致3次调用editDistance' ...每次调用editDistance'三次......还有三次调用......春天来了!指数爆炸!幸运的是,大多数电话都是相同的!例如(使用-->表示“来电”而eD表示editDistance'):

eD s1 s2  --> eD s1 (init s2)             -- The parent
            , eD (init s1) s2
            , eD (init s1) (init s2)
eD (init s1) s2 --> eD (init s1) (init s2)         -- The first "child"
                  , eD (init (init s1)) s2
                  , eD (init (init s1)) (init s2) 
eD s1 (init s2) --> eD s1 (init (init s2))
                  , eD (init s1) (init s2)
                  , eD (init s1) (init (init s2))

通过考虑父母和两个直接孩子,我们可以看到呼叫ed (init s1) (init s2)已完成三次。另一个孩子也与父母分享电话,所有孩子彼此分享许多电话(和他们的孩子,提示Monty Python短剧)。

使用runMemo之类的函数来返回所使用的缓存结果的数量,这将是一个有趣的,也许是有益的练习。

答案 1 :(得分:5)

你需要记住editDistance'。有很多种方法可以做到这一点,例如,递归定义的数组。

答案 2 :(得分:2)

如前所述,memoization就是你所需要的。另外,您正在查看从右到左的编辑距离,对于字符串来说效率不高,无论方向如何,编辑距离都是相同的。那就是:editDistance (reverse a) (reverse b) == editDistance a b

为了解决memoization部分,有很多库可以帮助你。在下面的例子中,我选择了MemoTrie,因为它非常易于使用并且在这里表现良好。

import Data.MemoTrie(memo2)

editDistance' del sub ins = memf
  where
   memf = memo2 f
   f s1     []     = ins * length s1
   f []     s2     = ins * length s2
   f (x:xs) (y:ys)
     | x == y  = memf xs ys
     | otherwise = minimum [ del + memf xs (y:ys),
                             sub + memf (x:xs) ys,
                             ins + memf xs ys]

正如您所看到的,您只需添加备忘录即可。其余的是相同的,除了我们从列表的开头开始,直到最后。

答案 3 :(得分:1)

  

我知道Hackage上已经有editDistance实现,但是我需要它来处理任意标记列表,不一定是字符串

是否有有限数量的令牌?我建议您尝试简单地设计从标记到字符的映射。毕竟还有10,646 characters at your disposal

答案 4 :(得分:1)

这个版本比那些记忆版本快得多,但我仍然希望它能更快。 100字符长字符串可以正常工作。 我写的是其他距离(更改init函数和成本),并使用经典的动态编程数组技巧。 长行可以转换为顶部'do'的单独函数,但我喜欢这种方式。

import Data.Array.IO
import System.IO.Unsafe

editDistance = dist ini med

dist :: (Int -> Int -> Int) -> (a -> a -> Int ) -> [a] -> [a] -> Int
dist i f a b  = unsafePerformIO $ distM i f a b

-- easy to create other distances 
ini i 0 = i
ini 0 j = j
ini _ _ = 0
med a b = if a == b then 0 else 2


distM :: (Int -> Int -> Int) -> (a -> a -> Int) -> [a] -> [a] -> IO Int
distM ini f a b = do
        let la = length a
        let lb = length b

        arr <- newListArray ((0,0),(la,lb)) [ini i j | i<- [0..la], j<-[0..lb]] :: IO (IOArray (Int,Int) Int)

-- all on one line
        mapM_ (\(i,j) -> readArray arr (i-1,j-1) >>= \ld -> readArray arr (i-1,j) >>= \l -> readArray arr (i,j-1) >>= \d-> writeArray arr (i,j) $ minimum [l+1,d+1, ld + (f (a !! (i-1) ) (b !! (j-1))) ] ) [(i,j)| i<-[1..la], j<-[1..lb]]

        readArray arr (la,lb)

答案 5 :(得分:1)

人们建议您使用通用的memoization库,但是对于定义Levenshtein距离的简单任务,简单的动态编程就足够了。 一个非常简单的基于多态列表的实现:

distance s t = 
    d !!(length s)!!(length t) 
    where d = [ [ dist m n | n <- [0..length t] ] | m <- [0..length s] ]
          dist i 0 = i
          dist 0 j = j
          dist i j = minimum [ d!!(i-1)!!j+1
                             , d!!i!!(j-1)+1
                             , d!!(i-1)!!(j-1) + (if s!!(i-1)==t!!(j-1) 
                                                  then 0 else 1) 
                             ]

或者如果你需要长序列的真正速度,你可以使用一个可变数组:

import Data.Array
import qualified Data.Array.Unboxed as UA
import Data.Array.ST
import Control.Monad.ST


-- Mutable unboxed and immutable boxed arrays
distance :: Eq a => [a] -> [a] -> Int
distance s t = d UA.! (ls , lt)
    where s' = array (0,ls) [ (i,x) | (i,x) <- zip [0..] s ]
          t' = array (0,lt) [ (i,x) | (i,x) <- zip [0..] t ]
          ls = length s
          lt = length t
          (l,h) = ((0,0),(length s,length t))
          d = runSTUArray $ do
                m <- newArray (l,h) 0 
                for_ [0..ls] $ \i -> writeArray m (i,0) i
                for_ [0..lt] $ \j -> writeArray m (0,j) j
                for_ [1..lt] $ \j -> do
                              for_ [1..ls] $ \i -> do
                                  let c = if s'!(i-1)==t'! (j-1) 
                                          then 0 else 1
                                  x <- readArray m (i-1,j)
                                  y <- readArray m (i,j-1)
                                  z <- readArray m (i-1,j-1)
                                  writeArray m (i,j) $ minimum [x+1, y+1, z+c ]
                return m

for_ xs f =  mapM_ f xs