如何在连续内存中的Unboxed Vector中存储Haskell数据类型

时间:2014-04-05 14:34:49

标签: haskell

我想存储一个非参数的,解压缩的数据类型,如

data Point3D = Point3D {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Int

在未装箱的矢量中。 Data.Vector.Unboxed说:

  

特别是,未装箱的对矢量表示为未装箱的矢量对。

为什么?我希望我的Point3D在内存中一个接一个地布局,以便在顺序迭代它们时获得快速缓存本地访问 - 相当于C中的mystruct[1000] < / p>

使用Vector.Unboxed或其他方式,我该如何实现?


顺便说一下:vector-th-unbox同样发生,因为您只需将数据类型转换为(Unbox a, Unbox b) => Unbox (a, b) instance

1 个答案:

答案 0 :(得分:9)

我不知道为什么对的向量存储为向量对,但您可以轻松地为数据类型编写实例以按顺序存储元素。

{-# LANGUAGE TypeFamilies, MultiParamTypeClasses #-}

import qualified Data.Vector.Generic as G 
import qualified Data.Vector.Generic.Mutable as M 
import Control.Monad (liftM, zipWithM_)
import Data.Vector.Unboxed.Base

data Point3D = Point3D {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Int

newtype instance MVector s Point3D = MV_Point3D (MVector s Int)
newtype instance Vector    Point3D = V_Point3D  (Vector    Int)
instance Unbox Point3D

此时最后一行将导致错误,因为Point3D的矢量类型没有实例。它们可以写成如下:

instance M.MVector MVector Point3D where 
  basicLength (MV_Point3D v) = M.basicLength v `div` 3 
  basicUnsafeSlice a b (MV_Point3D v) = MV_Point3D $ M.basicUnsafeSlice (a*3) (b*3) v 
  basicOverlaps (MV_Point3D v0) (MV_Point3D v1) = M.basicOverlaps v0 v1 
  basicUnsafeNew n = liftM MV_Point3D (M.basicUnsafeNew (3*n))
  basicUnsafeRead (MV_Point3D v) n = do 
    [a,b,c] <- mapM (M.basicUnsafeRead v) [3*n,3*n+1,3*n+2]
    return $ Point3D a b c 
  basicUnsafeWrite (MV_Point3D v) n (Point3D a b c) = zipWithM_ (M.basicUnsafeWrite v) [3*n,3*n+1,3*n+2] [a,b,c]

instance G.Vector Vector Point3D where 
  basicUnsafeFreeze (MV_Point3D v) = liftM V_Point3D (G.basicUnsafeFreeze v)
  basicUnsafeThaw (V_Point3D v) = liftM MV_Point3D (G.basicUnsafeThaw v)
  basicLength (V_Point3D v) = G.basicLength v `div` 3
  basicUnsafeSlice a b (V_Point3D v) = V_Point3D $ G.basicUnsafeSlice (a*3) (b*3) v 
  basicUnsafeIndexM (V_Point3D v) n = do 
    [a,b,c] <- mapM (G.basicUnsafeIndexM v) [3*n,3*n+1,3*n+2]
    return $ Point3D a b c 

我认为大多数功能定义都是自解释的。点向量存储为Int s的向量,n点为3n3n+13n+2 Int s 。