Python的Numpy np.random.choice的scala等效项是什么?(scala中的随机加权选择)

时间:2018-11-07 03:28:59

标签: scala random sampling

我在寻找Scala的等效代码或python np.random.choice(Numpy为np)的基础理论。我有一个类似的实现,它使用Python的np.random.choice方法从概率分布中选择随机移动。

Python's code

输入列表:['pooh','兔子','小猪','克里斯托弗']和概率:[0.5、0.1、0.1、0.3]

鉴于每个输入元素的相关概率,我想从输入列表中选择一个值。

1 个答案:

答案 0 :(得分:1)

Scala标准库没有与np.random.choice等效的库,但是根据要模拟的选项/功能,构建自己的库应该不会太困难。

例如,这里是一种获取无限Stream个已提交项目的方法,其中任一项目相对于其他项目加权的可能性。

def weightedSelect[T](input :(T,Int)*): Stream[T] = {
  val items  :Seq[T]    = input.flatMap{x => Seq.fill(x._2)(x._1)}
  def output :Stream[T] = util.Random.shuffle(items).toStream #::: output
  output
}

每个输入项都有一个乘数。因此,要获得对字符cv的无限伪随机选择,其中c出现的时间是3/5的时间,v出现的时间是2/5的时间:

val cvs = weightedSelect(('c',3),('v',2))

因此,np.random.choice(aa_milne_arr,5,p=[0.5,0.1,0.1,0.3])示例的大致等效项是:

weightedSelect("pooh"-> 5
              ,"rabbit" -> 1
              ,"piglet" -> 1
              ,"Christopher" -> 3).take(5).toArray

或者也许您想要一个更好的(伪伪少的)随机分布,它可能会严重偏重。

def weightedSelect[T](items :Seq[T], distribution :Seq[Double]) :Stream[T] = {
  assert(items.length == distribution.length)
  assert(math.abs(1.0 - distribution.sum) < 0.001) // must be at least close

  val dsums  :Seq[Double] = distribution.scanLeft(0.0)(_+_).tail
  val distro :Seq[Double] = dsums.init :+ 1.1 // close a possible gap
  Stream.continually(items(distro.indexWhere(_ > util.Random.nextDouble())))
}

结果仍然是指定元素的无限Stream,但是传入的参数有些不同。

val choices :Stream[String] = weightedSelect( List("this"     , "that")
                                           , Array(4998/5000.0, 2/5000.0))

// let's test the distribution
val (choiceA, choiceB) = choices.take(10000).partition(_ == "this")

choiceA.length  //res0: Int = 9995
choiceB.length  //res1: Int = 5  (not bad)