是否可以将列表中的项乘以常量?

时间:2018-01-20 08:46:21

标签: scala

考虑以下代码(现在不正确):

def sum_of_products(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] = {
  weights.zip(points).map((weight, point) => weight * point).sum
}

我的pointSeq[Double]。顺便说一句,我希望将point乘以其关联的weight(“关联”,因为两者都存在于同一对中。)

那么,如何将Seq[Double]乘以Double?我想我不能使用map,因为它会返回Seq[Double]。因此,以下代码似乎不正确:

def sum_of_products(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] = {
  weights.zip(points).map((weight, point) => point.map(coordinate => weight * coordinate)).sum
}

确实,我在这里将(weight, point)转换为Seq[Double]。但我想将其转换为Double

我正在进行线性插值。

解决方案(即:n维度中线性插值分量的实现):

def sum_of_products(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] = {
  weights.zip(points).map(
    weight_point => weight_point._2.map(coordinate => weight_point._1 * coordinate)
  ).reduce((point_a : Seq[Double], point_b : Seq[Double]) => point_a.zip(point_b).map(coordinate_points => coordinate_points._1 + coordinate_points._2))
}

2 个答案:

答案 0 :(得分:1)

在您的代码中:

weights.zip(points).map((weight, point) ...

map从2元组到某种类型T采用函数,但是你的代码传递给map一个带2个输入的函数,这个函数不同于2元组。这是函数的正确实现:

def sum_of_products(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] = {
    weights.zip(points).map { 
         case (weight, point) =>
             point.map(coordinate => weight * coordinate).sum 
    }
}

关键是在地图中添加单词case,以便您可以使用模式匹配从2元组中提取weightpoint。这在功能上等同于:

def sum_of_products(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] = {
    weights.zip(points).map { 
         wp =>
             wp._2.map(coordinate => wp._1 * coordinate).sum 
    }
}

答案 1 :(得分:1)

用例语法扩展带花括号的地图中的元组

weights.zip(poDoubles).map{case(weight, poDouble) => weight * poDouble}

现在,如果你将一个Double的序列乘以一个Double,你最终会得到一个Double的Seq,因为序列的每个元素都会乘以给定的double。你可以这样做:

返回Seq [Seq [Double]]并相应地操作结果

def sumOfProducts(weights: Seq[Double], points: Seq[Seq[Double]]): Seq[Seq[Double]] =
  weights.zip(points).map{case (weight, point) => point.map(coordinate => weight * coordinate)}

或者使用平面图来平整你的序列并返回Seq [Double]

Seq(Seq(1.0,2.0,3.0),Seq(1.0,2.0,3.0))将是Seq(1.0,2.0,3.0,1.0,2.0,3.0)

def sumOfProducts(weights: Seq[Double], points: Seq[Seq[Double]]) : Seq[Double] =
  weights.zip(points).flatMap{case (weight, point) => point.map(coordinate => weight * coordinate)}