Haskell中嵌套的三角形循环?

时间:2011-06-20 13:19:13

标签: haskell

我在Java中有以下内容,它基本上是一个嵌套的三角形循环:

    int n = 10;
    B bs[] = new B[n];
    // some initial values, bla bla
    double dt = 0.001;
    for (int i = 0; i < n; i++) {
        bs[i] = new B();
        bs[i].x = i * 0.5;
        bs[i].v = i * 2.5;
        bs[i].m = i * 5.5;
    }
    for (int i = 0; i < n; i++) {
        for (int j = **(i+1)**; j < n; j++) {
            double d = bs[i].x - bs[j].x;

            double sqr = d * d + 0.01;
            double dist = Math.sqrt(sqr);
            double mag = dt / (sqr * dist);

            bs[i].v -= d * bs[j].m * mag;
            **bs[j].v += d * bs[i].m * mag;**
        }
    }   

    // printing out the value v
    for (int i = 0; i < n; i++) {
        System.out.println(bs[i].v);
    }

B组:

class B {
    double x, v, m;
}

在每次迭代中,数组的索引i和j处的值同时更新,从而避免执行完整的嵌套循环。以下给出了相同的结果,但它完成了一个完整的嵌套循环(请原谅我正在使用的术语,它们可能不正确,但我希望它确实有意义。)

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            double d = bs[i].x - bs[j].x;

            double sqr = d * d + 0.01;
            double dist = Math.sqrt(sqr);
            double mag = dt / (sqr * dist);

            bs[i].v -= d * bs[j].m * mag;
        }
    }

注意: 与上一代码相比,唯一的变化是int j = 0; NOT int j = (i+1);并已移除bs[j].v += d * bs[i].m * mag;

我想在Haskell中做同样的事情,但很难正确地思考它。我有以下代码。 Haskell版本中的数组表示为一个列表(xs),我已将其初始化为0。

n = 20
xs = replicate n 0

update = foldl' (update') xs [0..(n-1)]
    where
        update' i = update'' i (i+1) []
        update'' i j acc
            | j == n = acc
            | otherwise = new_acc
                where
                    new_acc = result:acc
                    result = ...do something

我将对n有很大的价值,例如1000,5000等 当n = 1000时,完整的嵌套循环给出length [(i,j)|i<-[0..1000],j<-[0..1000]] = 1002001但三角形版本给出length [(i,j)|i<-[0..1000],j<-[(i+1)..1000]] = 500500。在Haskell中做2个地图很容易让它完成循环,但我想要三角形版本。我想这意味着保持对列表中i和j的更改,然后在最后更新原始列表?任何想法将不胜感激。感谢

3 个答案:

答案 0 :(得分:4)

这是使用来自vector包的未装箱的可变载体的直接翻译。代码有点难看,但应该非常快:

module Main
    where

import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

numElts :: Int
numElts = 10

dt :: Double
dt = 0.001

loop :: Int -> M.IOVector Double -> M.IOVector Double 
        -> M.IOVector Double -> IO ()
loop n x v m = go 0
  where
    doWork i j = do xI <- M.read x i
                    xJ <- M.read x j
                    vI <- M.read v i
                    vJ <- M.read v j
                    mI <- M.read m i
                    mJ <- M.read m j

                    let d = xI - xJ
                    let sqr = d * d + 0.01
                    let dist = sqrt sqr
                    let mag = dt / (sqr * dist)

                    M.write v i (vI - d * mJ * mag)
                    M.write v j (vJ + d * mI * mag)

    go i | i < n     = do go' (i+1)
                          go  (i+1)
         | otherwise = return ()
      where
        go' j | j < n     = do doWork i j
                               go' (j + 1)
              | otherwise = return ()

main :: IO ()
main = do x <- generateVector 0.5
          v <- generateVector 2.5
          m <- generateVector 5.5
          loop numElts x v m
          v' <- U.unsafeFreeze v
          U.forM_ v' print

    where
      generateVector :: Double -> IO (M.IOVector Double)
      generateVector d = do v <- M.new numElts
                            generateVector' numElts d v
                            return v

      generateVector' :: Int -> Double -> M.IOVector Double -> IO ()
      generateVector' n d v = go 0
        where
          go i | i < n = do M.unsafeWrite v i (fromIntegral i * d)
                            go (i+1)
               | otherwise = return ()

更新:关于“非常快”的声明:我benchmarked my solution反对Federico提供的纯粹声明并得到以下结果(n = 1000):< / p>

benchmarking pureSolution
collecting 100 samples, 1 iterations each, in estimated 334.5483 s
mean: 2.949640 s, lb 2.867693 s, ub 3.005429 s, ci 0.950
std dev: 421.1978 ms, lb 343.8233 ms, ub 539.4906 ms, ci 0.950
found 4 outliers among 100 samples (4.0%)
  3 (3.0%) high severe
variance introduced by outliers: 5.997%
variance is slightly inflated by outliers

benchmarking pureVectorSolution
collecting 100 samples, 1 iterations each, in estimated 280.4593 s
mean: 2.747359 s, lb 2.709507 s, ub 2.803392 s, ci 0.950
std dev: 237.7489 ms, lb 179.3110 ms, ub 311.8813 ms, ci 0.950
found 13 outliers among 100 samples (13.0%)
  7 (7.0%) high mild
  6 (6.0%) high severe
variance introduced by outliers: 2.998%
variance is slightly inflated by outliers

benchmarking imperativeSolution
collecting 100 samples, 1 iterations each, in estimated 5.905104 s
mean: 58.59154 ms, lb 56.79405 ms, ub 60.60033 ms, ci 0.950
std dev: 11.70101 ms, lb 9.120100 ms, ub NaN s, ci 0.950

所以必要的解决方案是约。比功能性的快50倍(当所有内容都适合缓存时,差异对于较小的n更为显着)。我试图使Federico的解决方案与未装箱的载体一起工作,但显然它以一种关键的方式依赖于懒惰,这使得未装箱的版本永远循环。 “纯矢量”版本使用了盒装矢量。

答案 1 :(得分:3)

我不确定这会解决你的问题,因为我还没有完全掌握它,但三角形循环本身在Haskell中很容易做到:

triangularLoop :: (a -> a -> b) -> [a] -> [b]
triangularLoop f xs = do
    (x1 : t) <- tails xs
    x2 <- t
    return $ f x1 x2

或者,没有monadic语法编写,

triangularLoop f = concat . map singlePass . tails
    where
        singlePass [] = []
        singlePass (h:t) = map (f h) t

答案 2 :(得分:2)

在Haskell中编写嵌套循环的典型惯用方法是使用列表推导。

以下是我翻译代码的方法:

import Data.Array

import Data.List (tails)

data Body = Body {x::Double,v::Double,m::Double}
            deriving Show

n::Int
n = 9

dt::Double
dt = 0.001

bs_0 :: Array Int Body
bs_0 = array (0,n) [(i,Body {x = i'*0.5,v = i'*2.5,m = i'*5.5}) | 
                    i <- [0..n], let i' = fromIntegral i]

bs :: Array Int Body
bs = accum (\b dv -> b {v = v b + dv}) bs_0 dvs
     where 
       dvs :: [(Int,Double)]
       dvs = concat [[(i,dv_i),(j,dv_j)] | (i:is) <- tails [0..n], 
                                            j <- is,
                                            let d = x(bs!i) - x(bs!j)
                                                sqr =  d * d + 0.01
                                                dist = sqrt sqr
                                                mag = dt / (sqr * dist)
                                                dv_i = -d * m(bs!j) * mag
                                                dv_j =  d * m(bs!i) * mag]

main :: IO()
main = mapM_ print (assocs bs)