用于列表处理的Haskell优化受到懒惰评估

时间:2016-08-04 20:40:18

标签: algorithm haskell

我正在努力提高以下代码的效率。我想在给定点之前计算符号的所有出现次数(作为使用Burrows-Wheeler变换的模式匹配的一部分)。我如何计算符号有一些重叠。但是,当我试图实现看起来应该是更高效的代码时,结果效率会降低,而且我假设懒惰的评估和我对它的不了解是应该受到责备。

我对计数功能的第一次尝试是这样的:

count :: Ord a => [a] -> a -> Int -> Int
count list sym pos = length . filter (== sym) . take pos $ list

然后在匹配函数本身中:

matching str refCol pattern = match 0 (n - 1) (reverse pattern)
  where n = length str
        refFstOcc sym = length $ takeWhile (/= sym) refCol
        match top bottom [] = bottom - top + 1
        match top bottom (sym : syms) =
          let topCt = count str sym top
              bottomCt = count str sym (bottom + 1)
              middleCt = bottomCt - topCt
              refCt = refFstOcc sym
          in if middleCt > 0
               then match (refCt + topCt) (refCt + bottomCt - 1) syms
               else 0

(为了简洁起见 - 我正在通过Map记住refCol中第一次出现的符号,以及其他一些细节)。

修改:示例使用将是:

matching "AT$TCTAGT" "$AACGTTTT" "TCG"

应为1(假设我没有错误输入任何内容)。

现在,我正在重述top指针和bottom两次中间的所有内容,当我计算一百万个字符的DNA字符串,只有4个可能的字符选择时,这就加起来了(和剖析告诉我,这也是一个很大的瓶颈,我将48%的时间用于bottomCt,大约38%的时间用于topCt)。作为参考,当计算一百万个字符串并尝试匹配50个模式(每个模式在1到1000个字符之间)时,程序运行大约需要8.5到9.5秒。

但是,如果我尝试实现以下功能:

countBetween :: Ord a => [a] -> a -> Int -> Int -> (Int, Int)
countBetween list sym top bottom =
  let (topList, bottomList) = splitAt top list
      midList = take (bottom - top) bottomList
      getSyms = length . filter (== sym)
  in (getSyms topList, getSyms midList)

(对匹配函数进行了更改以进行补偿),程序运行需要18到22秒。

我也试过传入一个可以跟踪之前调用的Map,但是这也需要大约20秒的时间来运行并运行内存使用。

同样,我已将length . filter (== sym)缩短为fold,但同样为foldr缩短为20秒,foldl为14-15。

那么通过重写来优化这段代码的Haskell方法是什么呢? (具体来说,我正在寻找一些不涉及预计算的东西 - 我可能不会重复使用字符串 - 这解释了为什么会发生这种情况)。

修改:更清楚的是,我要找的是以下内容:

a)为什么在Haskell中会发生这种行为?惰性求值如何发挥作用,编译器对重写countcountBetween函数进行了哪些优化,以及可能涉及的其他因素?

b)什么是简单的代码重写,可以解决这个问题,这样我就不会多次遍历列表了?我正在寻找能够解决这个问题的东西,而不是一个可以避开它的解决方案。如果最终答案是,count是编写代码的最有效方法,为什么会这样?

2 个答案:

答案 0 :(得分:1)

似乎match例程的要点是 将间隔(bottom,top)转换为另一个间隔 基于当前符号sym。公式是 基本上是:

ref_fst = index of sym in ref_col
  -- defined in an outer scope

match :: Char -> (Int,Int) -> (Int,Int)
match sym (bottom, top) | bottom > top =  (bottom, top) -- if the empty interval
match sym (bottom, top) =
  let 
    top_count = count of sym in str from index 0 to top
    bot_count = count of sym in str from index 0 to bottom
    mid_count = top_count - bot_count
  in if mid_count > 0
         then (ref_fst + bot_count, ref_fst + top_count)
         else (1,0)  -- the empty interval

然后使用matchingpattern折叠起来match 初始间隔为(0, n-1)

top_countbot_count都可以有效计算 使用预先计算的查找表,下面是代码 那样做。

如果你运行test1,你会看到间隔的痕迹 通过模式中的每个符号进行转换。

注意:可能有1个错误,我已经硬编码了 ref_fst为0 - 我不确定这是如何适应的 更大的算法,但基本的想法应该是合理的。

请注意,一旦创建了counts向量 不再需要索引到原始字符串。 因此,即使我在这里使用ByteString (较大的)DNA序列,它并不重要,而且 如果传递一个String,mkCounts例程应该也能正常工作 代替。

代码也可在http://lpaste.net/174288

获取
{-# LANGUAGE OverloadedStrings #-}

import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Unboxed.Mutable as UVM
import qualified Data.ByteString.Char8 as BS
import Debug.Trace
import Text.Printf
import Data.List

mkCounts :: BS.ByteString -> UV.Vector (Int,Int,Int,Int)
mkCounts syms = UV.create $ do
  let n = BS.length syms
  v <- UVM.new (n+1)
  let loop x i | i >= n = return x
      loop x i = let s = BS.index syms i
                     (a,t,c,g) = x
                     x' = case s of
                            'A' -> (a+1,t,c,g)
                            'T' -> (a,t+1,c,g)
                            'C' -> (a,t,c+1,g)
                            'G' -> (a,t,c,g+1)
                            _   -> x
                 in do UVM.write v i x
                       loop x' (i+1) 
  x <- loop (0,0,0,0) 0
  UVM.write v n x
  return v

data DNA = A | C | T | G
  deriving (Show)

getter :: DNA -> (Int,Int,Int,Int) -> Int
getter A (a,_,_,_) = a
getter T (_,t,_,_) = t
getter C (_,_,c,_) = c
getter G (_,_,_,g) = g

-- narrow a window
narrow :: Int -> UV.Vector (Int,Int,Int,Int) -> DNA -> (Int,Int) ->  (Int,Int)

narrow refcol counts sym (lo,hi) | trace msg False = undefined
  where msg = printf "-- lo: %d  hi: %d  refcol: %d  sym: %s  top_cnt: %d  bot_count: %d" lo hi refcol (show sym) top_count bot_count
        top_count = getter sym (counts ! (hi+1))
        bot_count = getter sym (counts ! lo)

narrow refcol counts sym (lo,hi) =
  let top_count = getter sym (counts ! (hi+1))
      bot_count = getter sym (counts ! (lo+0))
      mid_count = top_count - bot_count
  in if mid_count > 0
       then ( refcol + bot_count, refcol + top_count-1 )
       else (lo+1,lo)  -- signal an wmpty window

findFirst :: DNA -> UV.Vector (Int,Int,Int,Int)  -> Int
findFirst sym v =
  let n = UV.length v
      loop i | i >= n = n
      loop i = if getter sym (v ! i) > 0
                 then i
                 else loop (i+1)
  in loop 0

toDNA :: String -> [DNA]
toDNA str = map charToDNA str

charToDNA :: Char -> DNA
charToDNA = go
  where go 'A' = A
        go 'C' = C
        go 'T' = T
        go 'G' = G

dnaToChar A = 'A'
dnaToChar C = 'C'
dnaToChar T = 'T'
dnaToChar G = 'G'

first :: DNA -> BS.ByteString -> Int
first sym str = maybe len id (BS.elemIndex (dnaToChar sym) str)
  where len = BS.length str

test2 = do
 -- matching "AT$TCTAGT" "$AACGTTTT" "TCG"
  let str    = "AT$TCTAGT"
      refcol = "$AACGTTTT"
      syms   = toDNA "TCG"

      -- hard coded for now
      -- may be computeed an memoized
      refcol_G = 4
      refcol_C = 3
      refcol_T = 5

      counts = mkCounts str
      w0 = (0, BS.length str -1)

      w1 = narrow refcol_G counts G w0
      w2 = narrow refcol_C counts C w1
      w3 = narrow refcol_T counts T w2

      firsts = (first A refcol, first T refcol, first C refcol, first G refcol)

  putStrLn $ "firsts: " ++ show firsts

  putStrLn $ "w0: " ++ show w0
  putStrLn $ "w1: " ++ show w1
  putStrLn $ "w2: " ++ show w2
  putStrLn $ "w3: " ++ show w3
  let (lo,hi) = w3
      len = if lo <= hi then hi - lo + 1 else 0
  putStrLn $ "length: " ++ show len

matching :: BS.ByteString -> BS.ByteString -> String -> Int
matching  str refcol pattern = 
  let counts = mkCounts str
      n = BS.length str
      syms = toDNA (reverse pattern)
      firsts = (first A refcol, first T refcol, first C refcol, first G refcol)

      go (lo,hi) sym = narrow refcol counts sym (lo,hi)
        where refcol = getter sym firsts

      (lo, hi) = foldl' go (0,n-1) syms
      len = if lo <= hi then hi - lo + 1 else 0
  in len

test3 = matching "AT$TCTAGT" "$AACGTTTT" "TCG"

答案 1 :(得分:1)

我不确定懒惰的评估与代码的性能有多大关系。我认为主要的问题是使用String - 这是一个链表 - 而不是更高效的字符串类型。

请注意countBetween函数中的此次调用:

  let (topList, bottomList) = splitAt top list

将重新创建与topList含义相对应的链接链接 更多的分配。

比较splitAt与使用take n/drop n的标准基准 可以在这里找到:http://lpaste.net/174526splitAt版本是 大约慢3倍,当然,还有更多的分配。

即使你不想预先计算&#34;你可以提高的数量 只需切换到ByteString或Text就可以了解很多。

定义:

countSyms :: Char -> ByteString -> Int -> Int -> Int
countSyms sym str lo hi =
  length [ i | i <- [lo..hi], BS.index str i == sym ]

然后:

countBetween :: ByteString -> Char -> Int -> Int -> (Int,Int)
countBetween str sym top bottom = (a,b)
  where a = countSyms sym str 0 (top-1)
        b = countSyms sym str top (bottom-1)

另外,不要在大型​​列表中使用reverse - 它会重新分配 整个清单。只需反向索引ByteString / Text。

记忆计数可能会有所帮助,也可能没有帮助。这一切都取决于它是如何完成的。