在Haskell中高效实现Fisher精确检验

时间:2018-02-18 20:25:29

标签: haskell functional-programming

我正在尝试在Haskell中实现Fisher's exact test,因此给定四个自然数a,b,c和d,我想计算公式:

$$p = \frac{(a+b)!\,(a+c)!\,(b+d)!\,(c+d)!}{a!\,b!\,c!\,d!\,(a+b+c+d)!}$$

我尝试了3种实现,但需要更高效的实现:

解决方案1:

module Main where

import Data.Ratio

-- | Factorial of @n@, obtained by starting 'fact_acc' with an accumulator of 1.
factori n = fact_acc n 1

-- | Accumulating worker behind 'factori'. Multiplies the accumulator by @n@,
-- then by @n - 1@, and so on, returning it once the counter reaches 0.
-- The updated accumulator is forced before every recursive call so that no
-- chain of pending multiplications builds up.
fact_acc n a
  | n == 0    = a
  | otherwise = let a' = n * a in a' `seq` fact_acc (n - 1) a'

-- Cell counts of the sample 2x2 table.
a, b, c, d :: Integer
a = 1
b = 9
c = 7
d = 3

-- Each marginal factorial divided by one cell factorial. The divisions are
-- exact because the divisor is the factorial of a summand of the dividend's
-- argument.
n1, n2, n3, n4 :: Integer
n1 = factori (a + b) `div` factori a
n2 = factori (a + c) `div` factori c
n3 = factori (b + d) `div` factori b
n4 = factori (c + d) `div` factori d

-- | Product of the four quotients above.
numer :: Integer
numer = n1 * n2 * n3 * n4

-- | Factorial of the table total.
denom :: Integer
denom = factori (a + b + c + d)

-- | @numer / denom@ as a 'Double'.
--
-- The quotient is formed as an exact 'Rational' and converted once.
-- Converting 'numer' and 'denom' to 'Double' separately overflows to
-- Infinity (and the quotient to NaN) as soon as the factorials exceed the
-- 'Double' range, even though the quotient itself is an ordinary number.
p :: Double
p = fromRational (numer % denom)

-- | Prints 'denom' and then 'p', one per line.
main = print denom >> print p

解决方案2(抱歉,代码行很长):

module Main where

-- | Factorial of @n@; seeds 'fact_acc' with the accumulator 1.
factori n = fact_acc n 1
-- | Worker for 'factori': counts @n@ down to 0, folding each value into the
-- accumulator @a@, which is evaluated strictly at every step.
fact_acc n a = case n of
  0 -> a
  _ -> fact_acc (n - 1) $! n * a

-- | Product of all integers from @m@ to @n@ inclusive.
--
-- An empty range (@m > n@) yields 1, the empty product. The earlier
-- version only stopped when @m == n@, so an empty range never terminated.
mul_from_to m n = mul_acc m n 1

-- | Worker for 'mul_from_to': multiplies the accumulator @a@ by @m@,
-- @m + 1@, ..., @n@ and returns it once @m@ has passed @n@. The accumulator
-- is forced at each step.
mul_acc m n a
  | m > n     = a
  | otherwise = mul_acc (m + 1) n $! (m * a)

-- | Probability of the 2x2 table with cells @a@, @b@, @c@, @d@:
--
-- > (a+b)! (a+c)! (b+d)! (c+d)! / (a! b! c! d! (a+b+c+d)!)
--
-- The value is computed as @C(a+b, a) * C(c+d, c) / C(a+b+c+d, a+c)@, which
-- is the same quotient with the factorials grouped into binomial
-- coefficients. The division is carried out on exact rationals and only the
-- final result is converted to the requested 'Fractional' type.
--
-- All arithmetic is done on 'Integer', so narrow input types cannot
-- overflow. Cell counts of 0 are handled (empty products are 1). Negative
-- cell counts raise an error.
compute_p :: (Integral a, Fractional b) => a -> a -> a -> a -> b
compute_p a b c d
  | any (< 0) [a, b, c, d] = error "compute_p: cell counts must be non-negative"
  | otherwise = fromRational (fromInteger numer / fromInteger denom)
  where
    a' = toInteger a
    b' = toInteger b
    c' = toInteger c
    d' = toInteger d

    numer = choose (a' + b') a' * choose (c' + d') c'
    denom = choose (a' + b' + c' + d') (a' + c')

    -- Binomial coefficient. Uses the smaller of k and m - k so that as few
    -- factors as possible are multiplied; the division is exact.
    choose :: Integer -> Integer -> Integer
    choose m k = rangeProduct (m - j + 1) m `div` rangeProduct 1 j
      where
        j = min k (m - k)

    -- Product of lo..hi inclusive; 1 when the range is empty.
    rangeProduct :: Integer -> Integer -> Integer
    rangeProduct lo hi = product [lo .. hi]

-- Cell counts handed to 'compute_p'.
a, b, c, d :: Integer
a = 50000
b = 910
c = 11
d = 300

-- | Result of 'compute_p' for the cell counts above.
p :: Double
p = compute_p a b c d

-- | Prints 'p'.
main :: IO ()
main = print p

解决方案3:

module Main where

import Data.Ratio

-- | @factorial n@ is n!, looked up in the shared 'factorials' table.
--
-- @factorial 0@ is 1; it is handled separately because the table starts at
-- 1!. A negative argument raises an error. The result is an 'Integer', so
-- values beyond 20! no longer wrap around the 'Int' range.
factorial :: Int -> Integer
factorial n
  | n < 0     = error "factorial: negative argument"
  | n == 0    = 1
  | otherwise = factorials !! pred n

-- | Lazily built, shared table of factorials: element @i@ holds @(i + 1)!@.
-- The list is unbounded, so only the prefix that is actually demanded gets
-- computed and no upper limit has to be known in advance.
factorials :: [Integer]
factorials = scanl1 (*) [1 ..]

-- Cell counts of the sample 2x2 table. They are 'Int' because 'factorial'
-- indexes a list with them.
a, b, c, d :: Int
a = 1
b = 9
c = 7
d = 3

-- | Table total.
maxim :: Int
maxim = a + b + c + d

-- Each marginal factorial divided by one cell factorial. The divisions are
-- exact because the divisor is the factorial of a summand of the dividend's
-- argument.
n1 = factorial (a + b) `div` factorial a
n2 = factorial (a + c) `div` factorial c
n3 = factorial (b + d) `div` factorial b
n4 = factorial (c + d) `div` factorial d

-- Product of the four quotients above.
numer = n1 * n2 * n3 * n4
-- Factorial of the table total.
denom = factorial (a + b + c + d)

-- | @numer / denom@ as a 'Double'.
--
-- The quotient is formed as an exact 'Rational' and converted once, instead
-- of rounding numerator and denominator to 'Double' separately. That avoids
-- the extra rounding step, and avoids Infinity/NaN when the factorials are
-- arbitrary-precision integers larger than the 'Double' range.
p :: Double
p = fromRational (toInteger numer % toInteger denom)

-- | Prints 'denom' followed by 'p', each on its own line.
main = print denom *> print p

1 个答案:

答案 0 :(得分:6)

你计算

factorial (a+b) `div` factorial a

多次,只是a和b的取值不同。只需将a+1到a+b之间的数字相乘即可改进这一点;这样既减少了乘法的总次数,又完全避免了除法,所以应该会有一些帮助。

取决于数据规模,使用树形折叠而不是严格的左折叠,可以提高大量乘法运算的性能(因为两个大小相近的数相乘,比一个大数乘以一个小数更高效)。例如:

-- | Balanced ("tree-shaped") fold. Neighbouring elements are combined
-- pairwise, pass after pass, until a single value is left. An empty list
-- yields the supplied default; a singleton list yields its only element
-- unchanged.
foldb' :: (a -> a -> a) -> a -> [a] -> a
foldb' combine unit = reduce
  where
    reduce xs = case xs of
      []       -> unit
      [single] -> single
      _        -> reduce (pairUp xs)

    -- One pass: combine elements two at a time, forcing each combined value
    -- before it is placed in the result list. A trailing unpaired element
    -- is carried over unchanged.
    pairUp (l : r : more) = merged `seq` (merged : pairUp more)
      where
        merged = combine l r
    pairUp leftover = leftover

然后,您可以使用foldb' (*) 1来计算乘积,这比您现有的递归写法要快一些。

但我认为这两项改进的效果都相当有限,而且肯定不是渐近意义上的改进。(更新:在我的测试中,使用树形折叠实际上带来了非常大的提升:计算factorial 100000时,用foldl'需要943毫秒,用foldb'只需18毫秒,约快50倍。)