在Haskell中从列表中无需替换的随机样本的更好方法

时间:2012-12-08 16:56:30

标签: list haskell

我需要从较长的列表中随机取样而不替换(每个元素仅在样本中出现一次)。我正在使用下面的代码,但现在我想知道:

  1. 是否有库函数可以执行此操作?
  2. 如何改进此代码? (我是一个Haskell初学者,所以即使有一个库函数,这也很有用。)
  3. 抽样的目的是能够概括从样本分析到人群的结果。

    import System.Random
    
    -- | Take a random sample without replacement of size size from a list.
    takeRandomSample :: Int -> Int -> [a] -> [a]
    takeRandomSample seed size xs
        | size < hi  = subset xs rs
        | otherwise = error "Sample size must be smaller than population."
        where
            rs = randomSample seed size lo hi
            lo = 0
            hi = length xs - 1
    
    getOneRandomV g lo hi = randomR (lo, hi) g
    
    rsHelper size lo hi g x acc
        | x `notElem` acc && length acc < size = rsHelper size lo hi new_g new_x (x:acc)
        | x `elem` acc && length acc < size = rsHelper size lo hi new_g new_x acc
        | otherwise = acc
        where (new_x, new_g) = getOneRandomV g lo hi
    
    -- | Get a random sample without replacement of size size between lo and hi.
    randomSample seed size lo hi = rsHelper size lo hi g x [] where
    (x, g)  = getOneRandomV (mkStdGen seed) lo hi
    
    subset l = map (l !!) 
    

2 个答案:

答案 0 :(得分:6)

以下是Daniel Fischer在评论中建议的快速“背后”实现,使用我首选的PRNG(mwc-random):

{-# LANGUAGE BangPatterns #-}

module Sample (sample) where

import Control.Monad.Primitive
import Data.Foldable (toList)
import qualified Data.Sequence as Seq
import System.Random.MWC

sample :: PrimMonad m => [a] -> Int -> Gen (PrimState m) -> m [a]
sample ys size = go 0 (l - 1) (Seq.fromList ys) where
    l = length ys
    go !n !i xs g | n >= size = return $! (toList . Seq.drop (l - size)) xs
                  | otherwise = do
                      j <- uniformR (0, i) g
                      let toI  = xs `Seq.index` j
                          toJ  = xs `Seq.index` i
                          next = (Seq.update i toI . Seq.update j toJ) xs
                      go (n + 1) (i - 1) next g
{-# INLINE sample #-}

这几乎是(简洁)功能性重写R的内部C版sample(),因为它被称为无需替换。

sample只是一个递归工作函数的包装器,它会逐渐改变填充,直到达到所需的样本大小,只返回那么多的混洗元素。编写这样的函数可以确保GHC可以内联它。

它易于使用:

*Main> create >>= sample [1..100] 10
[51,94,58,3,91,70,19,65,24,53]

生产版本可能希望使用类似可变向量而不是Data.Sequence的内容,以减少执行GC所花费的时间。

答案 1 :(得分:2)

我认为执行此操作的标准方法是使用前N个元素保持固定大小的缓冲区,并且对于每个第i个元素,i&gt; = N,执行此操作:

  1. 在0和i之间选择一个随机数j,
  2. 如果j < N然后用当前的元素替换缓冲区中的第j个元素。
  3. 您可以通过归纳证明正确性:

    如果您只有N个元素,这显然会生成随机样本(我假设顺序无关紧要)。现在假设它属于第i个元素。这意味着任何元素在缓冲区中的概率为N /(i + 1)(我从0开始计数)。

    在选择随机数后,第i + 1个元素在缓冲区中的概率为N /(i + 2)(j在0和i + 1之间,其中N个最终在缓冲区中)。其他人怎么样?

    P(k'th element is in the buffer after processing the i+1'th) =
    P(k'th element was in the buffer before)*P(k'th element is not replaced) =
    N/(i+1) * (1-1/(i+2)) =
    N/(i+2)
    

    以下是使用标准(慢速)System.Random在样本大小空间中执行此操作的一些代码。

    import Control.Monad (when)                                                                                                       
    import Data.Array                                                                                                                 
    import Data.Array.ST                                                                                                              
    import System.Random (RandomGen, randomR)                                                                                         
    
    sample :: RandomGen g => g -> Int -> [Int] -> [Int]                                                                               
    sample g size xs =                                                                                                                
      if size < length xs                                                                                                             
      then error "sample size must be >= input length"                                                                                
      else elems $ runSTArray $ do                                                                                                    
        arr <- newListArray (0, size-1) pre                                                                                         
        loop arr g size post                                                                                                          
      where                                                                                                                           
        (pre, post) = splitAt size xs                                                                                                 
        loop arr g i [] = return arr                                                                                                  
        loop arr g i (x:xt) = do                                                                                                      
          let (j, g') = randomR (0, i) g                                                                                              
          when (j < size) $ writeArray arr j x                                                                                        
          loop arr g' (i+1) xt