修复嵌套数组定义,导致“按顺序执行嵌套并行计算...”

时间:2011-11-21 15:50:35

标签: haskell repa

作为一个更大问题的一部分,我试图在数组中定义一个数组,如下所示:

import Data.Array.Repa
type Arr = Array DIM2 Int

arr = force $ fromList (Z :. 5 :. 5) [1..25] :: Arr

combined :: Arr
combined = arr `deepSeqArray` 
    traverse arr (\_ -> Z :. 4 :. 4 :: DIM2) (\f (Z :. x :. y) -> 
        let reg = force $ extract f (x,y) (2,2)
        in  reg `deepSeqArray` sumAll reg)

extract :: (DIM2 -> Int) -> (Int,Int) -> (Int,Int) -> Arr
extract lookup (x0,y0) (width,height) = fromFunction bounds 
  $ \sh -> offset lookup sh
    where 
    bounds = Z :. width :. height
    offset :: (DIM2 -> Int) -> DIM2 -> Int
    offset f (Z :. x :. y) = f (Z :. x + x0 :. y + y0)

main = print combined

extract函数正在使用fromFunction并向其提供查找功能,但它也可以使用traversearr ! ...获得相同的效果。尽管尽早在任何地方使用forcedeepSeqArray,但控制台在此处填充了消息,然后是正确的结果:

  

Data.Array.Repa:按顺序执行嵌套并行计算。     你可能在另一个实例中调用了'force'函数     已经运行。如果第二个版本被暂停,则会发生这种情况     懒惰的评价。使用'deepSeqArray'确保每个数组都是完整的     在你“强迫”下一个之前进行评估。

虽然我没有构建一个带有列表的版本来比较速度,但是在更大的版本中,性能正在受到影响。

这仅仅是嵌套数组定义的结果,因此我应该重构我的程序,以便将内部或外部定义作为列表?我的extract功能是否可怕并且是问题的原因?

提示from this question对于实现这一目标非常有用,但我还没有浏览已编译的代码。

1 个答案:

答案 0 :(得分:2)

这是因为'print'也会隐式强制数组。内部'force'和'sumAll'函数调用并行计算,但是'print',所以你有嵌套的并行性。这个非常明显的事实在Repa 2 API中是一种极大的悲伤。

Repa 3通过导出'force'和'sumAll'等顺序和并行版本解决了这些问题。它还为数组类型添加了一个标记,以指示数组是延迟还是显示。 Repa 3尚未完成,但您可以使用http://code.ouroborus.net/repa上的头版本。在今年晚些时候的GHC 7.4之后,它应该是短暂的。

这是您的示例的Repa 3版本,它运行时没有给出有关嵌套并行性的警告。请注意,'force'现在是'compute'。

import Data.Array.Repa

arr :: Array U DIM2 Int
arr = fromListUnboxed (Z :. 5 :. 5) [1..25]

combined :: Array U DIM2 Int
combined 
  = computeP $ traverse arr (\_ -> Z :. 4 :. 4 :: DIM2) 
  $ \f (Z :. x :. y) -> sumAllS $ extract f (x,y) (2,2)

extract :: (DIM2 -> Int) -> (Int,Int) -> (Int,Int) -> Array D DIM2 Int
extract lookup (x0,y0) (width,height) 
  = fromFunction bounds 
  $ \sh -> offset lookup sh
    where 
    bounds = Z :. width :. height
    offset :: (DIM2 -> Int) -> DIM2 -> Int
    offset f (Z :. x :. y) = f (Z :. x + x0 :. y + y0)

main = print combined