避免递归中的模式匹配

时间:2013-06-01 08:21:50

标签: haskell applicative

考虑一下我用来解决欧拉问题58的代码:

diagNums = go skips 2
    where go (s:skips) x = let x' = x+s
                           in x':go skips (x'+1)

squareDiagDeltas = go diagNums
    where go xs = let (h,r) = splitAt 4 xs
                  in h:go r

我不喜欢第二个函数中的模式匹配。它看起来比必要的复杂!这对我来说经常出现。在这里,splitAt返回一个元组,所以我必须先解构它才能递归。当我的递归本身返回我想要修改的元组时,相同的模式可能更令人讨厌。考虑:

f n = go [1..n]
    where go [] = (0,0)
          go (x:xs) = let (y,z) = go xs
                      in (y+x, z-x)

与简单的递归相比:

f n = go [1..n]
    where go [] = 0
          go (x:xs) = x+go xs

当然这里的功能纯属无稽之谈,可以用完全不同的更好的方式编写。但我的观点是,每次我需要通过递归回调多个值时,就会出现模式匹配的需要。

有没有办法避免这种情况,可能是使用Applicative或类似的东西?或者你会认为这种风格是惯用的吗?

3 个答案:

答案 0 :(得分:6)

首先,这种风格实际上是惯用的。由于你对两个不同的值做两件事,所以有一些不可简化的复杂性;实际的模式匹配本身并没有多少介绍。此外,我个人发现大多数时候显式风格非常易读。

然而,还有另一种选择。 Control.Arrow有许多用于处理元组的函数。由于函数箭头->也是Arrow,所有这些都适用于正常函数。

因此,您可以使用(***)重写第二个示例,以组合两个函数来处理元组。此运算符具有以下类型:

(***) :: a b c -> a b' c' -> a (b, b') (c, c')

如果我们将a替换为->,我们会得到:

(***) :: (b -> c) -> (b' -> c') -> ((b, b') -> (c, c'))

因此,您可以将(+ x)(- x)合并为(+ x) *** (- x)的单个函数。这相当于:

\ (a, b) -> (a + x, b - x)

然后你可以在你的递归中使用它。不幸的是,-运算符是愚蠢的,并且不能分段工作,所以你必须用lambda编写它:

(+ x) *** (\ a -> a - x) $ go xs 

你很明显可以想象使用任何其他运算符,所有非常愚蠢:)。

老实说,我认为这个版本的可读性低于原版。但是,在其他情况下,***版本可以更具可读性,因此了解它是有用的。特别是,如果您将(+ x) *** (- x)传递给更高阶函数而不是立即应用它,我认为***版本会比显式lambda更好。

答案 1 :(得分:4)

我同意Tikhon Jelvis的说法,你的版本没有任何问题。就像他说的那样,使用Control.Arrow中的组合器可以用于更高阶函数。您可以使用折叠编写f

f n = foldr (\x -> (+ x) *** subtract x) (0,0) [1..n]

如果你真的想摆脱let中的squareDiagDeltas(我不确定我会),你可以使用second,因为你只是在修改第二个元组的元素:

squareDiagDeltas = go diagNums
  where go = uncurry (:) . second go . splitAt 4

答案 2 :(得分:4)

我同意hammarunfoldr is the way to go here

您还可以摆脱diagNums中的模式匹配:

diagNums = go skips 2
    where go (s:skips) x = let x' = x+s
                           in x':go skips (x'+1)

递归使得有点难以分辨这里发生了什么,所以让我们来吧 深入研究。

假设skips = s0 : s1 : s2 : s3 : ...,我们有:

diagNums = go skips 2
         = go (s0 : s1 : s2 : s3 : ...) 2 
         = s0+2 : go (s1 : s2 : s3 : ... ) (s0+3)
         = s0+2 : s0+s1+3 : go (s2 : s3 : ... ) (s0+s1+4) 
         = s0+2 : s0+s1+3 : s0+s1+s2+4 : go (s3 : ... ) (s0+s1+s2+5) 
         = s0+2 : s0+s1+3 : s0+s1+s2+4 : s0+s1+s2+s3+5 : go (...) (s0+s1+s2+s3+6) 

这使得更清楚的是,我们得到了两个序列的总和,使用zipWith (+)很容易计算:

diagNums = zipWith (+) [2,3,4,5,...] [s0, s0+s1, s0+s1+s2, s0+s1+s2+s3,...] 

所以现在我们只需找到一种更好的方法来计算skips的部分和,这对scanl1非常有用:

scanl1 (+) skips = s0 : s0+s1 : s0+s1+s2 : s0+s1+s2+s3 : ...

让(IMO)更容易理解diagNums的定义:

diagNums = zipWith (+) [2..] $ scanl1 (+) skips