为什么我在这种情况下使用Data.Array.Repa.Algorithms.Matrix.mmultP得到类型错误?

时间:2015-11-07 16:36:34

标签: haskell repa

我认为这归结为一种愚蠢的错误,但我无法弄清楚。我有以下代码:

{-# LANGUAGE TypeOperators #-}

import qualified Data.Array.Repa as R
import System.Random
import Data.Random.Normal
import Data.List
import Data.Function 
import qualified Data.Array.Repa.Algorithms.Matrix as M

weights:: (Int,Int)->(Double,Double) -> R.Array R.U (R.Z R.:. Int R.:. Int) Double

weights (nodes,features) range = R.fromListUnboxed (R.Z R.:. nodes R.:. features) values
where 
    values = take (nodes * features) $ randomRs range gen
    gen = mkStdGen 1

powerIteration :: R.Array R.U (R.Z R.:. Int R.:. Int) Double -> R.Array R.U (R.Z R.:. Int R.:. Int) Double

powerIteration m = M.mmultP m rb 
    where 
        rb = weights (rows,1) (0,1)
        (R.Z R.:. _ R.:. rows) = R.extent m


 main = do 
    let matrix = weights (10,3) (0,1)
    print $ powerIteration matrix 

这将创建一个包含10行和3列的2D数组的Repa表示,我想将(使用M.mmultP)的点积与随机数的1D数组(在0,1之间)与列长度​​等于2D数组的行号。我得到类似的东西在ghci中工作,但这可能会产生这个错误;

Couldn't match type ‘R.Array R.U R.DIM2 Double’ with ‘Double’
Expected type: R.Array R.U ((R.Z R.:. Int) R.:. Int) Double
  Actual type: R.Array
                 R.U ((R.Z R.:. Int) R.:. Int) (R.Array R.U R.DIM2 Double)
In the expression: M.mmultP m rb
In an equation for ‘powerIteration’:
    powerIteration m
      = M.mmultP m rb
      where
          rb = weights (rows, 1) (0, 1)
          (R.Z R.:. _ R.:. rows) = R.extent m  

这是一种类型错误,但我似乎无法弄明白。有人可以帮助我吗?

1 个答案:

答案 0 :(得分:0)

是的,评论在哪里正确,IO monad应该在那里:

{-# LANGUAGE TypeOperators #-}

import Data.Array.Repa
import System.Random
import Data.List
import Data.Function 
import Data.Array.Repa.Algorithms.Matrix

weights :: (Int, Int) -> (Double, Double) -> Array U (Z :. Int :. Int) Double
weights (nodes, features) range = fromListUnboxed (Z :. nodes :. features) values
  where 
    values = take (nodes * features) $ randomRs range gen
    gen = mkStdGen 1

powerIteration :: Array U (Z :. Int :. Int) Double -> IO (Array U (Z :. Int :. Int) Double)
powerIteration m = mmultP m rb 
  where 
    rb = weights (rows,1) (0,1)
    (Z :. _ :. rows) = extent m


main = do 
  let matrix = weights (10,3) (0,1)
  iter <- powerIteration matrix
  print iter