如何将这个基于列表的代码转换为使用可变数组?

时间:2012-04-02 21:21:20

标签: arrays haskell

EDIT3:我正在编写代码来处理Int的非常长的输入列表,只有几百个非重复项。我使用两个辅助列表来维持累积的部分和来计算一些累加器值,它们是如何以及为什么不重要。我想在这里抛弃所有列表并将其转换为良好的破坏性循环,我不知道如何。我不需要整个代码,只是一个骨架代码会很棒,对两个辅助数组进行读/写操作并返回一些最终结果。我现在拥有0.5小时的输入。我现在用C ++编写了这个代码,它在相同的输入下运行90秒。


我根本无法理解如何做到这一点。这是我现在拥有的基于列表的代码:(但下面的基于地图的代码更清晰)

ins :: (Num b, Ord a) => a -> b -> [(a, b)] -> ([(a, b)], b)
ins n x [] = ( [(n,x)], 0) 
ins n x l@((v, s):t) = 
  case compare n v of
    LT -> ( (n,s+x) : l , s )
    EQ -> ( (n,s+x) : t , if null t then 0 else snd (head t))
    GT -> let (u,z) = ins n x t
          in  ((v,s+x):u,z)

这用于循环,用于处理已知长度的数字列表(现在将其更改为foldl)

scanl g (0,([],[])) ns  -- ns :: [Int]
g ::
  (Num t, Ord t, Ord a) =>
  (t, ([(a, t)], [(a, t)])) -> a -> (t, ([(a, t)], [(a, t)])) 
g (c,( a, b)) n = 
    let
      (a2,x) = ins n 1 a
      (b2,y) = if x>0 then ins n x b else (b,0)
      c2     = c + y
    in
      (c2,( a2, b2))

这有效,但我需要加快速度。在C中,我会将列表(a,b)保留为数组;使用二进制搜索来查找具有恰好高于或等于n的密钥的元素(而不是此处使用的顺序搜索);并使用就地更新来更改前面的所有条目。

我只对最终价值感兴趣。在Haskell中如何使用可变数组完成此操作?

我尝试了一些东西,但我真的不知道我在这里做了什么,而且我收到了奇怪而且很长的错误信息(比如“不能从上下文中推断......”):

goarr top = runSTArray $ do
  let sz = 10000
  a <- newArray (1,sz) (0,0) :: ST s (STArray s Int (Integer,Integer))
  b <- newArray (1,sz) (0,0) :: ST s (STArray s Int (Integer,Integer))
  let p1 = somefunc 2 -- somefunc :: Integer -> [(Integer, Int)]
  go1 p1 2 0 top a b

go1 p1 i c top a b = 
    if i >= top
     then 
      do
       return c
     else
       go2 p1 i c top a b

go2 p1 i c top a b =
  do
   let p2 = somefunc (i+1)  -- p2 :: [(Integer, Int)]
   let n  = combine p1 p2   -- n :: Int
   -- update arrays and calc new c 
   -- like the "g" function is doing:
   --    (a2,x) = ins n 1 a
   --    (b2,y) = if x>0 then ins n x b else (b,0)
   --    c2     = c + y
   go1 p2 (i+1) c2 top a b  -- a2 b2??

这根本不起作用。我甚至不知道如何用符号编码循环。请帮忙。

UPD:运行速度慢3倍的基于地图的代码:

ins3 :: (Ord k, Num a) => k -> a -> Map.Map k a -> (Map.Map k a, a)
ins3 n x a | Map.null a = (Map.insert n x a , 0)
ins3 n x a = let (p,q,r) = Map.splitLookup n a in
  case q of 
    Nothing -> (Map.union (Map.map (+x) p) 
                 (Map.insert n (x+leftmost r) r) , leftmost r)
    Just s -> (Map.union (Map.map (+x) p) 
                 (Map.insert n (x+s) r) , leftmost r)

leftmost r | Map.null r = 0
           | otherwise = snd . head $ Map.toList r

UPD2:错误消息是“无法推断(Num(STArray s1 ie))来自context.h的文字`0'引起的上下文():417:11”< / p>

这就是return c函数中go1所说的位置。也许c可能是一个数组,但我想返回使用两个辅助数组时构建的累加器值。


EDIT3:我已根据Chris的建议将scanl(!!)替换为foldltake,现在它以常量运行具有明智的经验复杂性的空间,实际上预计在0.5小时内完成 - aot ... 3天!我当然知道这件事,但确信GHC会为我优化这些东西,肯定不会产生那么大的差别,我想!因此感觉只有可变阵列可以帮助......真可惜。

尽管如此,C ++在90秒内也是如此,我非常感谢在Haskell中学习如何使用可变数组进行编码的帮助。

2 个答案:

答案 0 :(得分:3)

输入值是否均为EQ?如果它们不是EQ,那么使用scanl g (0,([],[])) ns的方式意味着第一个[(,)]数组,在a的每个阶段都称其为map snd a == reverse [1..length a] g。例如,在长度为10的列表中,snd (a !! 4)的值将为10-4。通过改变a中每个前面条目的第二个值来保持这些反向索引值是非常浪费的。如果你需要速度,那么这是制作更好算法的地方。

这些都不适用于第二个[(,)],其目的对我来说仍然是神秘的。它记录了a末尾未完成的所有插入,因此可能允许重建初始值序列。

你说“我只对最终价值感兴趣。”你的意思是你只关心scanl ..行列表输出中的最后一个值吗?如果是,那么您需要foldl而不是scanl

编辑:我正在使用自定义手指树添加一个不可变的解决方案。它通过我的临时测试(在代码底部):

{-# LANGUAGE MultiParamTypeClasses #-}
import Data.Monoid
import Data.FingerTree

data Entry a v = E !a !v deriving Show

data ME a v = NoF | F !(Entry a v) deriving Show

instance Num v => Monoid (ME a v) where
  mempty = NoF
  NoF `mappend` k = k
  k `mappend` NoF = k
  (F (E _a1 v1)) `mappend` (F (E a2 v2)) = F (E a2 (v1 + v2))

instance Num v => Measured (ME a v) (Entry a v) where
  measure = F

type M a v = FingerTree (ME a v) (Entry a v)

getV NoF = 0
getV (F (E _a v)) = v

expand :: Num v => M a v -> [(a, v)]
expand m = case viewl m of
             EmptyL -> []
             (E a _v) :< m' -> (a, getV (measure m)) : expand m'

ins :: (Ord a, Num v) => a -> v -> M a v -> (M a v, v)
ins n x m =
  let comp (F (E a _)) = n <= a
      comp NoF = False
      (lo, hi) = split comp m
  in case viewl hi of
       EmptyL -> (lo |> E n x, 0)
       (E v s) :< higher | n < v ->
         (lo >< (E n x <| hi), getV (measure hi))
                         | otherwise ->
         (lo >< (E n (s+x) <| higher), getV (measure higher))

g :: (Num t, Ord t, Ord a) =>
     (t, (M a t, M a t)) -> a -> (t, (M a t, M a t))
g (c, (a, b)) n =
  let (a2, x) = ins n 1 a
      (b2, y) = if x>0 then ins n x b else (b, 0)
  in (c+y, (a2, b2))

go :: (Ord a, Num v, Ord v) => [a] -> (v, ([(a, v)], [(a, v)]))
go ns = let (t, (a, b)) = foldl g (0, (mempty, mempty)) ns
        in (t, (expand a, expand b))

up = [1..6]
down = [5,4..1]
see'tests = map go [ up, down, up ++ down, down ++ up ]

main = putStrLn . unlines . map show $ see'test

答案 1 :(得分:2)

略显不正统,我正在使用可变技术添加第二个答案。由于user1308992提到了Fenwick树,我用它们来实现算法。在运行期间分配和变异两个STUArray。基本的Fenwick树保留所有较小索引的总数,这里的算法需要所有较大索引的总计。此更改由(sz-x)减法处理。

import Control.Monad.ST(runST,ST)
import Data.Array.ST(STUArray,newArray)
import Data.Array.Base(unsafeRead, unsafeWrite)
import Data.Bits((.&.))
import Debug.Trace(trace)
import Data.List(group,sort)

{-# INLINE lsb #-}
lsb :: Int -> Int
lsb i = (negate i) .&. i

go :: [Int] -> Int
go xs = compute (maximum xs) xs

-- Require "top == maximum xs" and "all (>=0) xs"
compute :: Int -> [Int] -> Int
compute top xs = runST mutating where
  -- Have (sz - (top+1)) > 0 to keep algorithm simple
  sz = top + 2

  -- Reversed Fenwick tree (no bounds checking)
  insert :: STUArray s Int Int -> Int -> Int -> ST s ()
  insert arr x v = loop (sz-x) where
    loop i | i > sz = return ()
           | i <= 0 = error "wtf"
           | otherwise = do
      oldVal <- unsafeRead arr i
      unsafeWrite arr i (oldVal + v)
      loop (i + lsb i)

  getSum :: STUArray s Int Int -> Int -> ST s Int
  getSum arr x = loop (sz - x) 0 where
     loop i acc | i <= 0 = return acc
                | otherwise = do
       val <- unsafeRead arr i
       loop (i - lsb i) $! acc + val

  ins n x arr = do
    insert arr n x
    getSum arr (succ n)

  mutating :: ST s Int
  mutating = do
    -- Start index from 0 to make unsafeRead, unsafeWrite easy
    a <- newArray (0,sz) 0 :: ST s (STUArray s Int Int)
    b <- newArray (0,sz) 0 :: ST s (STUArray s Int Int)
    let loop [] c = return c
        loop (n:ns) c = do
          x <- ins n 1 a
          y <- if x > 0
               then 
                 ins n x b
               else
                 return 0
          loop ns $! c + y
    -- Without debugging use the next line
    -- loop xs 0
    -- With debugging use the next five lines
    c <- loop xs 0
    a' <- see a
    b' <- see b
    trace (show (c,(a',b'))) $ do 
    return c

  -- see is only used in debugging
  see arr = do
    let zs = map head . group . sort $ xs
    vs <- sequence [ getSum arr z | z <- zs ]
    let ans = filter (\(a,v) -> v>0) (zip zs vs)
    return ans

up = [1..6]
down = [5,4..1]
see'tests = map go [ up, down, up ++ down, down ++ up ]

main = putStrLn . unlines . map show $ see'tests