Haskell无法推断出类型相等

时间:2013-02-06 17:04:31

标签: haskell automatic-differentiation

我有以下代码,无法编译:

  import Numeric.AD

  data Trainable a b = forall n . Floating n =>  Trainable ([n] -> a -> b) (a -> b -> [n] -> n) 

  trainSgdFull :: (Floating n, Ord n) => Trainable a b -> [n] -> a -> b -> [[n]]
  trainSgdFull (Trainable _ cost) init input target =  gradientDescent (cost input target) init

我想使用Trainable类型来表示可通过梯度下降训练的机器学习系统。第一个算法是传递函数,sencond是成本函数,a是输入类型,b是输出/目标类型,列表包含可学习参数。 编译器抱怨这个:

 src/MachineLearning/Training.hs:12:73:
Could not deduce (n1 ~ ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n)
from the context (Floating n, Ord n)
  bound by the type signature for
             trainSgdFull :: (Floating n, Ord n) =>
                             Trainable a b -> [n] -> a -> b -> [[n]]
  at src/MachineLearning/Training.hs:12:3-95
or from (Floating n1)
  bound by a pattern with constructor
             Trainable :: forall a b n.
                          Floating n =>
                          ([n] -> a -> b) -> (a -> b -> [n] -> n) -> Trainable a b,
           in an equation for `trainSgdFull'
  at src/MachineLearning/Training.hs:12:17-32
or from (Numeric.AD.Internal.Classes.Mode s)
  bound by a type expected by the context:
             Numeric.AD.Internal.Classes.Mode s =>
             [ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n]
             -> ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n
  at src/MachineLearning/Training.hs:12:56-95
  `n1' is a rigid type variable bound by
       a pattern with constructor
         Trainable :: forall a b n.
                      Floating n =>
                      ([n] -> a -> b) -> (a -> b -> [n] -> n) -> Trainable a b,
       in an equation for `trainSgdFull'
       at src/MachineLearning/Training.hs:12:17
Expected type: [ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n1]
               -> ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n1
  Actual type: [n] -> n
In the return type of a call of `cost'
In the first argument of `gradientDescent', namely
  `(cost input target)'

基本概念是否合适?如果是,我怎么能编译代码?

1 个答案:

答案 0 :(得分:6)

问题在于

data Trainable a b = forall n . Floating n =>  Trainable ([n] -> a -> b) (a -> b -> [n] -> n)

表示在

Trainable transfer cost

使用的n类型丢失了。所有已知的是,某些类型Guessme带有Floating实例,以便

transfer :: [Guessme] -> a -> b
cost :: a -> b -> [Guessme] -> Guessme

您可以使用仅适用于Trainable的功能或仅适用于Complex Float的功能构建Double,或者......

但是在

trainSgdFull :: (Floating n, Ord n) => Trainable a b -> [n] -> a -> b -> [[n]]
trainSgdFull (Trainable _ cost) init input target =  gradientDescent (cost input target) init

您正在尝试将cost用于提供Floating类型作为参数。

构建Trainable以使用类型n0,用户提供类型n1,这些可能相同或不同。因此编译器无法推断出它们是相同的。

如果您不想让n成为Trainable的类型参数,则需要使其包含多态函数,这些函数与每个 Floating一起使用输入来电者用品

data Trainable a b
    = Trainable (forall n. Floating n => [n] -> a -> b)
                (forall n. Floating n => a -> b -> [n] -> n)

(需要Rank2Types,或者,因为它正在被弃用,RankNTypes)。