在scala中坚持期货

时间:2015-05-26 15:35:07

标签: scala future

基本上我在cassandra上运行两个期货查询,然后我需要做一些计算并返回值(平均值)。

这是我的代码:

object TestWrapFuture {
  def main(args: Array[String]) {
    val category = 5392
    ExtensiveComputation.average(category).onComplete {
      case Success(s) => println(s)
      case Failure(f) => throw new Exception(f)
    }
  }
}

class ExtensiveComputation {

  val volume = new ListBuffer[Int]()

  def average(categoryId: Int): Future[Double] = {

    val productsByCategory = Product.findProductsByCategory(categoryId)

    productsByCategory.map { prods =>
      for (prod <- prods if prod._2) {
        Sku.findSkusByProductId(prod._1).map { skus =>
          skus.foreach(sku => volume += (sku.height.get * sku.width.get * sku.length.get))
        }
      }

      val average = volume.sum / volume.length
      average
    }
  }
}

object ExtensiveComputation extends ExtensiveComputation

那么问题是什么?

skus.foreach将结果值附加到ListBuffer中。由于一切都是异步的,当我试图在我的主要部分获得结果时,我得到一个错误,说我不能除以零。

实际上,由于我的Sku.findSkusByProduct返回Future,当我尝试计算平均值时,音量为空。

我应该在此计算之前阻止任何事情,还是应该做其他事情?

修改

好吧,我试图阻止这样:

  val volume = new ListBuffer[Int]()

  def average(categoryId: Int): Future[Double] = {

    val productsByCategory = Product.findProductsByCategory(categoryId)

    val blocked = productsByCategory.map { prods =>
      for (prod <- prods if prod._2) {
        Sku.findSkusByProductId(prod._1).map { skus =>
          skus.foreach(sku => volume += (sku.height.get * sku.width.get * sku.length.get))
        }
      }
    }

    Await.result(blocked, Duration.Inf)
    val average = volume.sum / volume.length
    Future.successful(average)
  }

然后我从这段代码中得到了两个不同的结果:

    Sku.findSkusByProductId(prod._1).map { skus =>
          skus.foreach(sku => volume += (sku.height.get * sku.width.get * sku.length.get))
        }

1 - 如果只有50个人可以在cassandra上查找,它只会运行并给我结果

2 - 当有很多像1000这样的时候,它给了我

  

java.lang.ArithmeticException:/ by zero

编辑2

我尝试了这个代码,因为@Olivier Michallat提议

   def average(categoryId: Int): Future[Double] = {

    val productsByCategory = Product.findProductsByCategory(categoryId)

    productsByCategory.map { prods =>
      for (prod <- prods if prod._2) findBlocking(prod._1)
      volume.sum / volume.length
    }
  }

  def findBlocking(productId: Long) = {
    val future = Sku.findSkusByProductId(productId).map { skus =>
      skus.foreach(sku => volume += (sku.height.get * sku.width.get * sku.length.get))
    }

    Await.result(future, Duration.Inf)
  }

以下为@kolmar提议:

   def average(categoryId: Int): Future[Int] = {
    for {
      prods <- Product.findProductsByCategory(categoryId)
      filtered = prods.filter(_._2)
      skus <- Future.traverse(filtered)(p => Sku.findSkusByProductId(p._1))
    } yield {
      val volumes = skus.flatten.map(sku => sku.height.get * sku.width.get * sku.length.get)
      volumes.sum / volumes.size
    }
  }

两个skus都可以找到50个,但是两个都失败了,很多skus找到了1000个抛出ArithmeticException:/ by zero

似乎在返回未来之前无法计算所有内容......

2 个答案:

答案 0 :(得分:3)

在计算平均值之前,您需要等到findSkusByProductId生成的所有期货都已完成。因此,在Seq中积累所有这些未来,在其上调用Future.sequence以获得Future[Seq],然后将该未来映射到计算平均值的函数。然后将productsByCategory.map替换为flatMap

答案 1 :(得分:2)

由于您必须在一系列参数上调用返回Future的函数,因此最好使用Future.traverse

例如:

object ExtensiveComputation {
  def average(categoryId: Int): Future[Double] = {
    for {
      products <- Product.findProductsByCategory(categoryId)
      filtered = products.filter(_._2)
      skus <- Future.traverse(filtered)(p => Sku.findSkusByProductId(p._1))
    } yield {
      val volumes = skus.map { sku => 
        sku.height.get * sku.width.get * sku.length.get }
      volumes.sum / volumes.size
    }
  }
}