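I'm working through some examples and trying to implement a function that counts how many subsets of a list add up to a given number.
I'm trying to rewrite some Python implementations in Haskell: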
import Data.List (sort)

test1 :: [Int]
test1 = [2,4,6,10,1,4,5,6,7,8]

countSets1 total input = length [n | n <- subsets $ sort input, sum n == total]
  where
    subsets [] = [[]]
    subsets (x:xs) = map (x:) (subsets xs) ++ subsets xs

countSets2 total input = go (reverse . sort $ input) total
  where
    go [] _ = 0
    go (x:xs) t
      | t == 0 = 1
      | t < 0 = 0
      | t < x = go xs t
      | otherwise = go xs (t - x) + go xs t

countSets3 total input = go (sort input) total (length input - 1)
  where
    go xxs t i
      | t == 0 = 1
      | t < 0 = 0
      | i < 0 = 0
      | t < (xxs !! i) = go xxs t (i-1)
      | otherwise = go xxs (t - (xxs !! i)) (i-1) + go xxs t (i-1)
I can't figure out why countSets2 doesn't return the same result as countSets3 (a copy of the Python version):
λ: countSets1 16 test1
24
λ: countSets2 16 test1
13
λ: countSets3 16 test1
24
Edit: @freestyle pointed out that the order of my conditions differs between the two solutions:
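-- sortBy comes from Data.List; (flip compare) sorts in descending order
countSets2 total input = go (sortBy (flip compare) input) total
  where
    go _ 0 = 1  -- target reached, even when no elements are left
    go [] _ = 0
    go (x:xs) t
      | t < 0 = 0
      | t < x = go xs t
      | otherwise = go xs (t - x) + go xs t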
This fixed the problem.
Answer 0 (score: 0)
I'm not sure about your logic, but in your second solution I think you need
go [] 0 = 1
Otherwise, your code gives go [] 0 = 0, which doesn't seem right.
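The missing case shows up on the smallest possible input. In the sketch below (the names countSetsBuggy and countSetsFixed are made up for illustration), the original countSets2 sits next to a version that recognises a zero target first. For [1] with target 1, the only successful path takes the 1 and then calls go [] 0, so the original returns 0 and the fixed version returns 1. More generally, because the list is sorted in descending order, the original loses every subset that uses the last element of the sorted list; on test1 that element is 1, and the 11 subsets containing it account for the gap between 24 and 13.

```haskell
import Data.List (sort)

-- countSets2 as first posted: `go [] _ = 0` is matched before the
-- `t == 0` guard is ever reached, so hitting the target exactly when
-- the list runs out is counted as a failure.
countSetsBuggy :: Int -> [Int] -> Int
countSetsBuggy total input = go (reverse (sort input)) total
  where
    go [] _ = 0
    go (x:xs) t
      | t == 0    = 1
      | t < 0     = 0
      | t < x     = go xs t
      | otherwise = go xs (t - x) + go xs t

-- Same search, but a zero target is recognised before the empty-list
-- case, so `go [] 0` is 1 (the clause order from the edited question).
countSetsFixed :: Int -> [Int] -> Int
countSetsFixed total input = go (reverse (sort input)) total
  where
    go _ 0 = 1
    go [] _ = 0
    go (x:xs) t
      | t < 0     = 0
      | t < x     = go xs t
      | otherwise = go xs (t - x) + go xs t

-- ghci> (countSetsBuggy 1 [1], countSetsFixed 1 [1])
-- (0,1)
-- ghci> (countSetsBuggy 16 [2,4,6,10,1,4,5,6,7,8], countSetsFixed 16 [2,4,6,10,1,4,5,6,7,8])
-- (13,24)
```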
Answer 1 (score: 0)
I won't go into your bug, so I don't expect you to accept my answer. I'm only offering an alternative solution:
import Math.Combinat.Sets (sublists)
getSublists :: [Int] -> Int -> [[Int]]
getSublists list total = filter (\x -> sum x == total) (sublists list)
countSublists :: [Int] -> Int -> Int
countSublists list total = length $ getSublists list total
The module Math.Combinat.Sets comes from the combinat package.
>>> countSublists [2,4,6,10,1,4,5,6,7,8] 16
24
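If you'd rather not pull in combinat, Data.List.subsequences from base also enumerates every sublist, so the same filter-and-count approach works without extra dependencies. A sketch (countSublistsBase is a made-up name); the order in which the sublists are generated may differ from combinat's sublists, which does not affect the count:

```haskell
import Data.List (subsequences)

-- Count the sublists (subsequences) of `list` whose elements sum to `total`.
-- subsequences enumerates all 2^(length list) of them, including [].
countSublistsBase :: [Int] -> Int -> Int
countSublistsBase list total = length (filter ((== total) . sum) (subsequences list))

-- ghci> countSublistsBase [2,4,6,10,1,4,5,6,7,8] 16
-- 24
```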
Answer 2 (score: 0)
This problem looks similar to a pearl written by Richard Bird on how many sums and products can make 100. I'll use it as a template here. First, the specification:
subseqn :: (Num a, Eq a) => a -> [a] -> Int
subseqn n = length . filter ((== n) . sum) . subseqs
  where
    subseqs = foldr prefix [[]]
    prefix x xss = map (x:) xss ++ xss
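The specification runs as it stands. In this sketch subseqs is lifted to the top level with its own signature (a change made only so it can be inspected by itself); it shows that every one of the 2^k sublists of a k-element list is built before any filtering happens:

```haskell
-- All sublists, built right to left: each element is either prefixed
-- to every existing candidate or left out.
subseqs :: [a] -> [[a]]
subseqs = foldr prefix [[]]
  where
    prefix x xss = map (x:) xss ++ xss

-- The specification: count the candidates whose sum is exactly n.
subseqn :: (Num a, Eq a) => a -> [a] -> Int
subseqn n = length . filter ((== n) . sum) . subseqs

-- ghci> subseqs [1,2,3]
-- [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]
-- ghci> length (subseqs [2,4,6,10,1,4,5,6,7,8])
-- 1024
-- ghci> subseqn 16 [2,4,6,10,1,4,5,6,7,8]
-- 24
```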
Observe that a lot of work may be wasted in subseqs. Intuitively, we can discard candidates as soon as their sum exceeds n, i.e. use the weaker predicate (<= n) somewhere. Trivially, filtering on the weaker predicate before filtering on the stronger one does not change the outcome. Then you can derive
filter ((== n) . sum) . subseqs
= {- insert weaker predicate -}
filter ((== n) . sum) . filter ((<= n) . sum) . subseqs
= {- definition of subseqs -}
filter ((== n) . sum) . filter ((<= n) . sum) . foldr prefix [[]]
= {- fusion law of foldr -}
filter ((== n) . sum) . foldr prefix' [[]]
The fusion law states that f . foldr g a = foldr h b provided that
1. f is strict,
2. f a = b, and
3. f (g x y) = h x (f y) for all x and y.
Here, a = b = [[]], f is filter ((<= n) . sum) and g is prefix. You can derive h (i.e. prefix') by observing that, as long as all elements are non-negative, the predicate can be applied before prefixing:
filter ((<= n) . sum) (prefix x xss) =
  filter ((<= n) . sum) (prefix x (filter ((<= n) . sum) xss))
which is exactly the third condition; then h is given by prefix' x = filter ((<= n) . sum) . prefix x.
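The fused definition is directly executable. The sketch below puts the specification (renamed subseqnSpec) next to the foldr prefix' version (subseqnPruned, also a name chosen here), with prefix' x = filter ((<= n) . sum) . prefix x, so the two can be compared. It assumes non-negative elements and a non-negative n, which is what makes pruning by (<= n) safe.

```haskell
-- Brute-force specification: enumerate every sublist, keep sums equal to n.
subseqnSpec :: (Num a, Eq a) => a -> [a] -> Int
subseqnSpec n = length . filter ((== n) . sum) . foldr prefix [[]]
  where
    prefix x xss = map (x:) xss ++ xss

-- After fusion: each step prefixes x and immediately drops candidates
-- whose sum already exceeds n. Only equivalent to the specification
-- when all elements (and n) are non-negative.
subseqnPruned :: (Num a, Ord a) => a -> [a] -> Int
subseqnPruned n = length . filter ((== n) . sum) . foldr prefix' [[]]
  where
    prefix' x = filter ((<= n) . sum) . prefix x
    prefix x xss = map (x:) xss ++ xss
```

On [2,4,6,10,1,4,5,6,7,8] with n = 16 both return 24.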
Another observation is that sum is computed too many times: every candidate's sum is recomputed from scratch each time the predicate is checked. To get around that, we can modify our definition of subseqn so that each candidate carries its own sum. Let's use
(&&&) :: (a -> b) -> (a -> c) -> a -> (b, c)
(&&&) f g x = (f x, g x)
and derive
filter ((== n) . sum) . subseqs
= {- use &&& -}
filter ((== n) . snd) . map (id &&& sum) . subseqs
= {- first derivation: subseqs fused with the (<= n) filter -}
filter ((== n) . snd) . map (id &&& sum) . foldr prefix' [[]]
= {- fusion law of foldr -}
filter ((== n) . snd) . foldr prefix'' [[]]
I won't go through the whole derivation of prefix''; it is quite long. The gist is that you can avoid calling sum at all by working on pairs, so that the sum is computed incrementally: it starts at 0 for the empty list, and all we have to do when prefixing a new element is add that element to it.
We update our base case from [[]] to [([], 0)] and get:
prefix'' x = filter ((<= n) . snd) . uncurry zip . (prefix x *** add x) . unzip
  where
    (***) :: (a -> a') -> (b -> b') -> (a, b) -> (a', b')
    (***) f g (x, y) = (f x, g y)
    add :: Num a => a -> [a] -> [a]
    add x xs = map (x+) xs ++ xs
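To see a single step of the pair-carrying fold, here is the expansion lifted out as a top-level function. Taking n as an explicit first argument and calling the function expand are assumptions made so it can run outside subseqn (prefix'' above closes over n instead); (***) is imported from Control.Arrow rather than defined locally.

```haskell
import Control.Arrow ((***))

-- One step of the fold: every candidate is paired with its sum.
-- unzip splits candidates from sums, prefix/add extend both halves in the
-- same order, zip re-pairs them, and pairs whose sum exceeds n are dropped.
expand :: (Num a, Ord a) => a -> a -> [([a], a)] -> [([a], a)]
expand n x = filter ((<= n) . snd) . uncurry zip . (prefix x *** add x) . unzip
  where
    prefix y yss = map (y:) yss ++ yss  -- new candidates, then old ones
    add y ys = map (y+) ys ++ ys        -- matching sums, same order

-- ghci> foldr (expand 8) [([], 0)] [3,5]
-- [([3,5],8),([3],3),([5],5),([],0)]
-- ghci> foldr (expand 7) [([], 0)] [3,5]
-- [([3],3),([5],5),([],0)]
```

With n = 7 the candidate [3,5] (sum 8) is discarded the moment it is formed, and each new sum costs a single addition instead of a full sum.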
Here is the final version:
import Control.Arrow ((***))

subseqn :: (Num a, Ord a) => a -> [a] -> Int
subseqn n = length . filter ((== n) . snd) . foldr expand [([], 0)]
  where
    expand x = filter ((<= n) . snd) . uncurry zip . (prefix x *** add x) . unzip
    prefix x xss = map (x:) xss ++ xss
    add x xs = map (x+) xs ++ xs
(*** and &&& are from Control.Arrow.)
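The final version needs only the Control.Arrow import to compile. In this sketch it is checked against the brute-force specification (renamed subseqnSpec here). One caveat inherited from the (<= n) pruning: it is only sound when every element is non-negative, since a negative element can bring an oversized sum back down to n.

```haskell
import Control.Arrow ((***))

-- Final version from the answer; only the import is new.
subseqn :: (Num a, Ord a) => a -> [a] -> Int
subseqn n = length . filter ((== n) . snd) . foldr expand [([], 0)]
  where
    -- Extend each (candidate, sum) pair with x, keeping sums <= n.
    expand x = filter ((<= n) . snd) . uncurry zip . (prefix x *** add x) . unzip
    prefix x xss = map (x:) xss ++ xss
    add x xs = map (x+) xs ++ xs

-- Brute-force reference (the specification), for comparison.
subseqnSpec :: (Num a, Eq a) => a -> [a] -> Int
subseqnSpec n = length . filter ((== n) . sum) . foldr (\x xss -> map (x:) xss ++ xss) [[]]

-- ghci> subseqn 16 [2,4,6,10,1,4,5,6,7,8]
-- 24
-- ghci> (subseqn 0 [-1,1], subseqnSpec 0 [-1,1])
-- (1,2)
```

On the question's input it returns 24, like countSets1 and countSets3. With [-1, 1] and n = 0 it finds only [], because [1] is pruned before -1 can be prefixed, while the specification also counts [-1, 1].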