记住[Integer]类型的函数 - >一个

时间:2015-01-07 16:14:50

标签: haskell memoization

我的问题是如何有效地记忆为所有有限整数列表定义的昂贵函数f :: [Integer] -> a并具有属性f . sort = f

我的典型用例是,给定一个整数列表as,我需要为各种整数a获取值f (a:as),所以我想同时建立一个有顶点的有向标记图是一对整数列表及其函数值。当且仅当a:as = bs时,存在由from(as,f as)到(bs,f bs)标记的边。

brilliant answer by Edward Kmett偷窃我只是复制了

{-# LANGUAGE BangPatterns #-}
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

并将他的想法改编为我的问题

-- directed graph labelled by Integers
data Graph a = Graph a (Tree (Graph a))
instance Functor Graph where
  fmap f (Graph a t) = Graph (f a) (fmap (fmap f) t)

-- walk the graph following the given labels
walk :: Graph a -> [Integer] -> a
walk (Graph a _) [] = a
walk (Graph _ t) (x:xs) = walk (index t x) xs

-- graph of all finite integer sequences
intSeq :: Graph [Integer]
intSeq = Graph [] (fmap (\n -> fmap (n:) intSeq) nats)

-- could be replaced by Data.Strict.Pair
data StrictPair a b = StrictPair !a !b
  deriving Show

-- f = sum modified according to Edward's idea (the real function is more complicated)
g :: ([Integer] -> StrictPair Integer [Integer]) -> [Integer] -> StrictPair Integer [Integer]
g mf [] = StrictPair 0 []
g mf (a:as) = StrictPair (a+x) (a:as)
  where StrictPair x y = mf as

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) intSeq

g_m :: [Integer] -> StrictPair Integer [Integer]
g_m = walk g_graph

这样可以正常工作,但由于函数f与出现的整数的顺序无关(但不是它们的计数),因此在图中只有一个顶点,所有整数列表都等于排序。

我如何实现这一目标?

4 个答案:

答案 0 :(得分:2)

如何定义g_m' = g_m . sort,即在调用memoized函数之前先简单地对输入列表进行排序?

我觉得这是你能做的最好的事情,因为如果你想让你的memoized图只包含有条件的路径,那么在构造路径之前,某人必须要查看列表中的所有元素。

根据您的输入列表的样子,以一种使树枝减少的方式转换它们可能会有所帮助。例如,您可以尝试排序并采取差异:

original input list:   [8,3,14,8,5]
sorted:                [3,3,8,8,14]
diffed:                [3,0,5,0,6] -- use this as the key

转换是一种双射,树的分支较少,因为涉及的数量较少。

答案 1 :(得分:2)

您可以使用不同的方法。 有证据证明可数集的有限积是可数的:

我们可以按[a1, ..., an]Nat将序列product . zipWith (^) primes映射到2 ^ a1 * 3 ^ a2 * 5 ^ a3 * ... * primen ^ an

为了避免最后的序列为零的问题,我们可以增加最后一个索引。

由于序列是有序的,我们可以将该属性用作user5402提到的。

使用树的好处是,您可以增加分支以加速遍历。 OTOH主要技巧可以使索引相当大,但希望一些树路径只是未开发的(保持为thunk)。

{-# LANGUAGE BangPatterns #-}

-- Modified from Kmett's answer:
data Tree a = Tree a (Tree a) (Tree a) (Tree a) (Tree a)
instance Functor Tree where
  fmap f (Tree x a b c d) = Tree (f x) (fmap f a) (fmap f b) (fmap f c) (fmap f d)

index :: Tree a -> Integer -> a
index (Tree x _ _ _ _) 0 = x
index (Tree _ a b c d) n = case (n - 1) `divMod` 4 of
  (q,0) -> index a q
  (q,1) -> index b q
  (q,2) -> index c q
  (q,3) -> index d q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree n (go a s') (go b s') (go c s') (go d s')
            where
                a = n + s
                b = a + s
                c = b + s
                d = c + s
                s' = s * 4

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- Primes -- https://www.haskell.org/haskellwiki/Prime_numbers
-- Generation and factorisation could be done much better
minus (x:xs) (y:ys) = case (compare x y) of
           LT -> x : minus  xs  (y:ys)
           EQ ->     minus  xs     ys
           GT ->     minus (x:xs)  ys
minus  xs     _     = xs

primes = 2 : sieve [3..] primes
  where
    sieve xs (p:ps) | q <- p*p , (h,t) <- span (< q) xs =
                   h ++ sieve (t `minus` [q, q+p..]) ps

addToLast :: [Integer] -> [Integer]
addToLast [] = []
addToLast [x] = [x + 1]
addToLast (x:xs) = x : addToLast xs

subFromLast :: [Integer] -> [Integer]
subFromLast [] = []
subFromLast [x] = [x - 1]
subFromLast (x:xs) = x : subFromLast xs

addSubProp :: [NonNegative Integer] -> Property
addSubProp xs = xs' === subFromLast (addToLast xs')
  where xs' = map getNonNegative xs

-- Trick from user5402 answer
toDiffList :: [Integer] -> [Integer]
toDiffList = toDiffList' 0
  where toDiffList' _ [] = []
        toDiffList' p (x:xs) = x - p : toDiffList' x xs

fromDiffList :: [Integer] -> [Integer]
fromDiffList = fromDiffList' 0
  where fromDiffList' _ [] = []
        fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs

diffProp :: [Integer] -> Property
diffProp xs = xs === fromDiffList (toDiffList xs)

listToInteger :: [Integer] -> Integer
listToInteger = product . zipWith (^) primes . addToLast

integerToList :: Integer -> [Integer]
integerToList = subFromLast . impl primes 0
  where impl _      _ 0 = []
        impl _      0 1 = []
        impl _      k 1 = [k]
        impl (p:ps) k n = case n `divMod` p of
                            (n', 0) -> impl (p:ps) (k + 1) n'
                            (_,  _) -> k : impl ps 0 n

listProp :: [NonNegative Integer] -> Property
listProp xs = xs' === integerToList (listToInteger xs')
  where xs' = map getNonNegative xs

toIndex :: [Integer] -> Integer
toIndex = listToInteger . toDiffList

fromIndex :: Integer -> [Integer]
fromIndex = fromDiffList . integerToList

-- [1,0] /= [0]
-- Decreasing sequence!
doesntHold :: [NonNegative Integer] -> Property
doesntHold xs = xs' === fromIndex (toIndex xs')
  where xs' = map getNonNegative xs

holds :: [NonNegative Integer] -> Property
holds xs = xs' === fromIndex (toIndex xs')
  where xs' = sort $ map getNonNegative xs

g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg = g' . sort
  where g' [] = 0
        g' (x:xs)  = x + sum (map mg $ tails xs)

g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

faster_g = faster_g' . sort

fix g [1..22]仍然快速燃烧时,我的机器faster_g [1..40]感觉很慢。


添加如果我们有界限集(索引 0..n-1 ),我们可以将其编码为:{{ 1}}。

我们可以将任何a0 * n^0 + a1 * n^1 ...编码为二进制列表,例如Integer11(排名第一)。 然后,如果我们将列表中的整数与[1, 1, 0, 1]分开,我们就会得到有界值的序列。

作为奖励,我们可以使用例如 0,1,2 数字和压缩的序列来使用例如二进制数。霍夫曼编码,因为2比0或1少得多。但这可能是过度的。

通过这个技巧,索引可以保持更小的空间并且空间可能更好。

2

第二次补充:

我快速为我的{-# LANGUAGE BangPatterns #-} -- From Kment's answer: import Data.Function (fix) import Data.List (sort, tails) import Data.List.Split (splitOn) import Test.QuickCheck {-- Tree definition as before --} -- 0, 1, 2 newtype N3 = N3 { unN3 :: Integer } deriving (Eq, Show) instance Arbitrary N3 where arbitrary = elements $ map N3 [ 0, 1, 2 ] -- Integer <-> N3 coeffs3 :: [Integer] coeffs3 = coeffs' 1 where coeffs' n = n : coeffs' (n * 3) listToInteger :: [N3] -> Integer listToInteger = sum . zipWith f coeffs3 where f n (N3 m) = n * m listFromInteger :: Integer -> [N3] listFromInteger 0 = [] listFromInteger n = case n `divMod` 3 of (q, m) -> N3 m : listFromInteger q listProp :: [N3] -> Property listProp xs = (null xs || last xs /= N3 0) ==> xs === listFromInteger (listToInteger xs) -- Integer <-> N2 -- 0, 1 newtype N2 = N2 { unN2 :: Integer } deriving (Eq, Show) coeffs2 :: [Integer] coeffs2 = coeffs' 1 where coeffs' n = n : coeffs' (n * 2) integerToBin :: Integer -> [N2] integerToBin 0 = [] integerToBin n = case n `divMod` 2 of (q, m) -> N2 m : integerToBin q integerFromBin :: [N2] -> Integer integerFromBin = sum . zipWith f coeffs2 where f n (N2 m) = n * m binProp :: NonNegative Integer -> Property binProp (NonNegative n) = n === integerFromBin (integerToBin n) -- unsafe! n3ton2 :: N3 -> N2 n3ton2 = N2 . unN3 n2ton3 :: N2 -> N3 n2ton3 = N3 . unN2 -- [Integer] <-> [N3] integerListToN3List :: [Integer] -> [N3] integerListToN3List = concatMap (++ [N3 2]) . map (map n2ton3 . integerToBin) integerListFromN3List :: [N3] -> [Integer] integerListFromN3List = init . map (integerFromBin . map n3ton2) . splitOn [N3 2] n3ListProp :: [NonNegative Integer] -> Property n3ListProp xs = xs' === integerListFromN3List (integerListToN3List xs') where xs' = map getNonNegative xs -- Trick from user5402 answer -- Integer <-> Sorted Integer toDiffList :: [Integer] -> [Integer] toDiffList = toDiffList' 0 where toDiffList' _ [] = [] toDiffList' p (x:xs) = x - p : toDiffList' x xs fromDiffList :: [Integer] -> [Integer] fromDiffList = fromDiffList' 0 where fromDiffList' _ [] = [] fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs diffProp :: [Integer] -> Property diffProp xs = xs === fromDiffList (toDiffList xs) --- toIndex :: [Integer] -> Integer toIndex = listToInteger . integerListToN3List . toDiffList fromIndex :: Integer -> [Integer] fromIndex = fromDiffList . integerListFromN3List . listFromInteger -- [1,0] /= [0] -- Decreasing sequence! doesn't terminate in this case doesntHold :: [NonNegative Integer] -> Property doesntHold xs = xs' === fromIndex (toIndex xs') where xs' = map getNonNegative xs holds :: [NonNegative Integer] -> Property holds xs = xs' === fromIndex (toIndex xs') where xs' = sort $ map getNonNegative xs g :: ([Integer] -> Integer) -> [Integer] -> Integer g mg = g' . sort where g' [] = 0 g' (x:xs) = x + sum (map mg $ tails xs) g_tree :: Tree Integer g_tree = fmap (g faster_g' . fromIndex) nats faster_g' :: [Integer] -> Integer faster_g' = index g_tree . toIndex faster_g = faster_g' . sort 基准图和二进制序列方法:

g

结果是:

main :: IO ()
main = do
  n <- read . head <$> getArgs
  print $ faster_g [100, 110..n]

通过常数因子 2 看起来图表版本更快。但它们似乎具有相同的时间复杂度:)

答案 2 :(得分:1)

通过简单地用单调版本替换intSeq定义中的g_graph,我的问题就解决了:

-- replace vertexes for non-monotone integer lists by the according monotone one
monoIntSeq :: Graph [Integer]
monoIntSeq = f intSeq
  where f (Graph as t) | as == sort as = Graph as $ fmap f t
                       | otherwise     = fetch monIntSeq $ sort as

-- extract the subgraph after following the given labels
fetch :: Graph a -> [Integer] -> Graph a
fetch g [] = g
fetch (Graph _ t) (x:xs) = fetch (index t x) xs

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) monoIntSeq

非常感谢所有人(特别是user5402和Oleg)的帮助!


编辑:我仍然遇到一个问题,即我的典型用例的内存消耗量很高,可以通过以下路径来描述:

p :: [Integer]
p = map f [1..]
  where f n | n `mod` 6 == 0 = n `div` 6
            | n `mod` 3 == 0 = n `div` 3
            | n `mod` 2 == 0 = n `div` 2
            | otherwise      = n

略微改进是直接定义单调整数序列:

-- extract the subgraph after following the given labels (right to left)
fetch :: Graph a -> [Integer] -> Graph a
fetch = foldl' step
  where step (Graph _ t) n = index t n

-- walk the graph following the given labels (right to left)
walk :: Graph a -> [Integer] -> a
walk g ns = a
  where Graph a _ = fetch g ns

-- all monotone falling integer sequences
monoIntSeqs :: Graph [Integer]
monoIntSeqs = Graph [] $ fmap (flip f monoIntSeqs) nats
  where f n (Graph ns t) | null ns      = Graph (n:ns) $ fmap (f n) t
                         | n >= head ns = Graph (n:ns) $ fmap (f n) t
                         | otherwise    = fetch monoIntSeqs (insert' n ns)
        insert' = insertBy (comparing Down)

但最后我可能只使用没有标识的原始整数序列,现在明确地识别节点,避免保留对g_graph等的引用,以便在程序进行时清理垃圾收集。

答案 3 :(得分:0)

阅读Richard Bird和Ralf Hinze撰写的功能珍珠Trouble Shared is Trouble Halved,我明白了如何实现,两年前我想要的东西(再次基于Edward Kmett的伎俩):

{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)

data Tree a = Tree (Tree a) a (Tree a)
  deriving Show

instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))

val :: IntSeqTree a -> a
val (IntSeqTree a _) = a

step :: Integer -> IntSeqTree t -> IntSeqTree t
step n (IntSeqTree _ ts) = index ts n

intSeqTree :: IntSeqTree [Integer]
intSeqTree = fix $ create []
  where create p x = IntSeqTree p $ fmap (extend x) nats
        extend x n = case span (>n) (val x) of
                       ([], p) -> fix $ create (n:p)
                       (m, p)  -> foldr step intSeqTree (m ++ n:p)

instance Functor IntSeqTree where
  fmap f (IntSeqTree a t) = IntSeqTree (f a) (fmap (fmap f) t)

在我的用例中,我有数百或数千个类似的整数序列(长度为几百个条目),这些序列是递增生成的。所以对我来说,这种方式比在查找函数值之前对序列进行排序要便宜(我将通过在intSeqTree上使用fmap来访问它)。