Learn You a Haskell显示Prob
newtype:
newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving Show
以下Prob
的定义:
instance Functor Prob where
fmap f (Prob xs) = Prob $ map (\(x,p) -> (f x,p)) xs
instance Monad Prob where
return x = Prob [(x, 1%1)]
p >>= f = flatten (fmap f p)
然后是支持功能:
flatten :: Prob (Prob a) -> Prob a
flatten = Prob . convert . getProb
convert :: [(Prob a, Rational)] -> [(a, Rational)]
convert = concat . (map f)
f :: (Prob a, Rational) -> [(a, Rational)]
f (p, r) = map (mult r) (getProb p)
mult :: Rational -> (a, Rational) -> (a, Rational)
mult r (x, y) = (x, r*y)
我写了flatten,convert,f和mult函数,所以我对它们很满意。
然后我们将>>=
应用于以下示例,涉及数据类型Coin
:
data Coin = Heads | Tails deriving (Show, Eq)
coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]
loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]
LYAH说,If we throw all the coins at once, what are the odds of all of them landing tails?
flipTwo:: Prob Bool
flipTwo= do
a <- coin -- a has type `Coin`
b <- loadedCoin -- similarly
return (all (== Tails) [a,b])
致电flipTwo
返回:
Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}
flipTwo
可以使用>>=
重写:
flipTwoBind' :: Prob Bool
flipTwoBind' = coin >>=
\x -> loadedCoin >>=
\y -> return (all (== Tails) [x,y])
我不理解return (all (== Tails) [x,y])
的类型。由于它是>>=
的右侧,因此其类型必须为a -> m b
(其中Monad m
)。
我的理解是(all (==Tails) [x,y])
会返回True or False
,但return
如何导致上述结果:
Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}
?
答案 0 :(得分:4)
请注意>>=
运算符的RHS是lambda表达式,不 return
的应用程序:
\y -> return (all (== Tails) [x,y])
此lambda具有预期的类型(Monad m) => a -> m b
。
让我们从底部构建类型:
正如您所说,all (== Tails) [x,y]
会返回True
或False
。换句话说,它的类型是Bool
。
现在,检查ghci中return
的类型,我们看到:
Prelude> :t return
return :: Monad m => a -> m a
所以return (all (==Tails) [x,y])
是Monad m => m Boolean
类型。
将其包含在lambda中,然后给出类型(Monad m) => a -> m Boolean
。
(请注意,在此过程中,编译器会推断出具体monad类型为Prob
。)
您应该将return
视为采用常规值并将其包装为Monad
。
<强>增加:强>
让我们分析一下
的类型flipTwoBind' = coin >>=
\x -> loadedCoin >>=
\y -> return (all (== Tails) [x,y])
我们首先注意到这里最外层的表达式是(>>=)
的应用程序,其类型为:
Prelude> :t (>>=)
(>>=) :: Monad m => m a -> (a -> m b) -> m b
LHS为coin
,其类型为Prob Coin
,因此我们立即推断m
为Prob
且a
为Coin
。这意味着RHS必须具有某种类型Coin -> Prob b
的类型b
。现在让我们来看看RHS:
\x -> loadedCoin >>= \y -> return (all (== Tails) [x,y])
这里我们有一个lambda,它返回(>>=)
的应用程序的结果,所以lambda有类型
(Monad m) => a -> m b
这与第一个(>>=)
的应用的预期类型相匹配,因此a
此处为Coin
,m
为Prob
。
现在分析(>>=)
的内部应用,我们看到它的类型被推断为
(>>=) :: Prob Coin -> (Prob -> Prob b) -> Prob b
我们已经分析了第二个(>>=)
的RHS,因此b
被推断为Bool
。
(注意,这可能不是编译器用来推断类型的确切顺序。它恰好是我在分析这个答案的类型时所遵循的顺序。)
答案 1 :(得分:2)
(我会打电话给你coin
fairCoin
)你有:
flipTwoBind' :: Prob Bool
flipTwoBind' = fairCoin >>= g where
g x = loadedCoin >>= h where
h y = return z where
z = all (== Tails) [x,y]
从我们得到的(>>=)
类型:
fairCoin :: Prob Coin
(>>=) :: Monad m => m a -> (a -> m b) -> m b | m ~ Prob, a ~ Coin
fairCoin >>= g :: m b | g :: Coin -> Prob b
flipTwoBind' :: Prob Bool | m ~ Prob, b ~ Bool
以便g :: Coin -> Prob Bool
和g x :: Prob Bool
提供x :: Coin
。
自g x = loadedCoin >>= h
以来,我们有
loadedCoin :: Prob Coin
(>>=) :: Monad m => m a -> (a -> m b) -> m b
loadedCoin >>= h :: Prob Bool
所以,h :: Coin -> Prob Bool
,z :: Bool
和return z :: Prob Bool
:
all :: (a -> Bool) -> [a] -> Bool
all p [] :: Bool
return :: (Monad m) => a -> m a
z :: Bool
return z :: m Bool | m ~ Prob so return z :: Prob Bool
由于Prob a
本质上是a
结果对及其相应概率的标记关联列表,Prob Bool
是Bool
结果及其概率的配对列表
使用特定的Prob
monadic代码进行翻译,内联所有函数,flipTwoBind'
变为
flipTwoBind' = fairCoin >>= g
= flatten (fmap g fairCoin)
= Prob . convert . getProb $
Prob $ map (\(x,p) -> (g x,p)) $ getProb fairCoin
= Prob . concat . map (\(x,p) -> map (\(x, y) -> (x, p*y)) $ getProb x)
. map (\(x,p) -> (g x,p)) $ getProb fairCoin
(看看Prob
和getProb
在内部相互取消的情况有多好......)。
切换到基于列表的普通代码(gL xs = getProb (g (Prob xs))
和fairCoinL = getProb fairCoin
等),相当于
= concat . map (\(x,p) -> map (second (p*)) x)
. map (\(x,p) -> (gL x,p)) $ fairCoinL
= concat . map (\(x,p) -> map (second (p*)) $ gL x) $ fairCoinL
= [(v,p*q) | (x,p) <- fairCoinL, (v,q) <- gL x]
= ....
= [(z,r) | (x,p) <- [(Heads, 1%2), (Tails, 1%2 )], -- do a <- fairCoin
(y,q) <- [(Heads, p*1%10), (Tails, p*9%10)], -- b <- loadedCoin
(z,r) <- [(all (== Tails) [x,y], q*1%1 )] ] -- return ... all ...
= [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]
当然,上面推导中最后一行之前的那一行同样可以写成
= [(all (== Tails) [x,y], q) -- ... all ... <$>
| (x,p) <- [(Heads, 1%2), (Tails, 1%2 )], -- fairCoin <*>
(y,q) <- [(Heads, p*1%10), (Tails, p*9%10)] ] -- loadedCoin
因为(>>= return . f) === fmap f
。