理解`newtype Prob`的`bind`

时间:2014-09-01 03:01:28

标签: haskell

Learn You a Haskell显示Prob newtype:

newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving Show

以下Prob的定义:

instance Functor Prob where  
    fmap f (Prob xs) = Prob $ map (\(x,p) -> (f x,p)) xs  

instance Monad Prob where
    return x = Prob [(x, 1%1)]
    p >>= f  = flatten (fmap f p)  

然后是支持功能:

flatten :: Prob (Prob a) -> Prob a
flatten = Prob . convert . getProb 

convert :: [(Prob a, Rational)] -> [(a, Rational)]
convert = concat . (map f)

f :: (Prob a, Rational) -> [(a, Rational)]
f (p, r) = map (mult r) (getProb p)

mult :: Rational -> (a, Rational) -> (a, Rational)
mult r (x, y) = (x, r*y)

我写了flatten,convert,f和mult函数,所以我对它们很满意。

然后我们将>>=应用于以下示例,涉及数据类型Coin

data Coin = Heads | Tails deriving (Show, Eq)

coin :: Prob Coin 
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]

LYAH说,If we throw all the coins at once, what are the odds of all of them landing tails?

flipTwo:: Prob Bool
flipTwo= do
  a <- coin       -- a has type `Coin`
  b <- loadedCoin -- similarly
  return (all (== Tails) [a,b])

致电flipTwo返回:

Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}

flipTwo可以使用>>=重写:

flipTwoBind' :: Prob Bool
flipTwoBind' = coin >>= 
                    \x -> loadedCoin   >>= 
                                       \y -> return (all (== Tails) [x,y])

我不理解return (all (== Tails) [x,y])的类型。由于它是>>=的右侧,因此其类型必须为a -> m b(其中Monad m)。

我的理解是(all (==Tails) [x,y])会返回True or False,但return如何导致上述结果:

Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}

2 个答案:

答案 0 :(得分:4)

请注意>>=运算符的RHS是lambda表达式, return的应用程序:

\y -> return (all (== Tails) [x,y])

此lambda具有预期的类型(Monad m) => a -> m b

让我们从底部构建类型:

正如您所说,all (== Tails) [x,y]会返回TrueFalse。换句话说,它的类型是Bool

现在,检查ghci中return的类型,我们看到:

Prelude> :t return
return :: Monad m => a -> m a

所以return (all (==Tails) [x,y])Monad m => m Boolean类型。

将其包含在lambda中,然后给出类型(Monad m) => a -> m Boolean

(请注意,在此过程中,编译器会推断出具体monad类型为Prob。)

您应该将return视为采用常规值并将其包装为Monad

<强>增加:

让我们分析一下

的类型
flipTwoBind' = coin >>= 
                \x -> loadedCoin   >>= 
                                   \y -> return (all (== Tails) [x,y])

我们首先注意到这里最外层的表达式是(>>=)的应用程序,其类型为:

Prelude> :t (>>=)
(>>=) :: Monad m => m a -> (a -> m b) -> m b

LHS为coin,其类型为Prob Coin,因此我们立即推断mProbaCoin。这意味着RHS必须具有某种类型Coin -> Prob b的类型b。现在让我们来看看RHS:

\x -> loadedCoin >>= \y -> return (all (== Tails) [x,y])

这里我们有一个lambda,它返回(>>=)的应用程序的结果,所以lambda有类型

(Monad m) => a -> m b

这与第一个(>>=)的应用的预期类型相匹配,因此a此处为CoinmProb

现在分析(>>=)的内部应用,我们看到它的类型被推断为

(>>=) :: Prob Coin -> (Prob -> Prob b) -> Prob b

我们已经分析了第二个(>>=)的RHS,因此b被推断为Bool

(注意,这可能不是编译器用来推断类型的确切顺序。它恰好是我在分析这个答案的类型时所遵循的顺序。)

答案 1 :(得分:2)

(我会打电话给你coin fairCoin)你有:

flipTwoBind' :: Prob Bool
flipTwoBind' = fairCoin     >>=  g   where
   g x       = loadedCoin   >>=  h   where
     h y     = return z              where 
       z     = all (== Tails) [x,y]

从我们得到的(>>=)类型:

fairCoin ::         Prob Coin
(>>=) :: Monad m => m    a    ->  (a -> m b) -> m b       | m ~ Prob, a ~ Coin
                    fairCoin  >>=    g       :: m b       | g :: Coin -> Prob b
flipTwoBind'                              :: Prob Bool    | m ~ Prob, b ~ Bool

以便g :: Coin -> Prob Boolg x :: Prob Bool提供x :: Coin

g x = loadedCoin >>= h以来,我们有

loadedCoin ::       Prob Coin
(>>=) :: Monad m => m    a    ->  (a -> m b) -> m    b 
                  loadedCoin  >>=    h       :: Prob Bool

所以,h :: Coin -> Prob Boolz :: Boolreturn z :: Prob Bool

all ::  (a -> Bool) -> [a] -> Bool
all        p           []  :: Bool

return :: (Monad m) => a -> m a
z      ::           Bool
return                 z :: m Bool           | m ~ Prob so return z :: Prob Bool

由于Prob a本质上是a结果对及其相应概率的标记关联列表,Prob BoolBool结果及其概率的配对列表


使用特定的Prob monadic代码进行翻译,内联所有函数,flipTwoBind'变为

flipTwoBind' = fairCoin     >>=  g
   = flatten (fmap g fairCoin)
   = Prob . convert . getProb $ 
             Prob $ map (\(x,p) -> (g x,p)) $ getProb fairCoin
   = Prob . concat . map (\(x,p) -> map (\(x, y) -> (x, p*y)) $ getProb x)
                  . map (\(x,p) -> (g x,p)) $ getProb fairCoin

(看看ProbgetProb在内部相互取消的情况有多好......)。

切换到基于列表的普通代码(gL xs = getProb (g (Prob xs))fairCoinL = getProb fairCoin等),相当于

   = concat . map (\(x,p) -> map (second (p*)) x)
            . map (\(x,p) -> (gL x,p)) $ fairCoinL
   = concat . map (\(x,p) -> map (second (p*)) $ gL x) $ fairCoinL
   = [(v,p*q) | (x,p) <- fairCoinL, (v,q) <- gL x]
   = ....
   = [(z,r)   | (x,p) <- [(Heads,   1%2),  (Tails,   1%2 )],   -- do a <- fairCoin
                (y,q) <- [(Heads, p*1%10), (Tails, p*9%10)],   --    b <- loadedCoin
                (z,r) <- [(all (== Tails) [x,y],   q*1%1 )] ]  --    return ... all ...
   = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]

当然,上面推导中最后一行之前的那一行同样可以写成

   = [(all (== Tails) [x,y], q)                                -- ... all ... <$>
              | (x,p) <- [(Heads,   1%2),  (Tails,   1%2 )],   --   fairCoin <*>
                (y,q) <- [(Heads, p*1%10), (Tails, p*9%10)] ]  --   loadedCoin

因为(>>= return . f) === fmap f