Kotlin用泛型计算标准偏差

时间:2017-12-17 19:18:30

标签: generics math kotlin

我想写一个方法来计算所提供数字的标准偏差。

这就是它的样子:

fun calculateSD(numArray: List<Int>): Double {
var sum = 0.0
var standardDeviation = 0.0

for (num in numArray) {
    sum += num
}

val mean = sum / numArray.size

for (num in numArray) {
    standardDeviation += Math.pow(num - mean, 2.0)
}

val divider = numArray.size - 1

return Math.sqrt(standardDeviation / divider)
}

但是,我希望这也适用于Double,Float等列表。

所以需要使用泛型:

fun <T>calculateSD(numArray: List<T>): Double {
var sum = 0.0
var standardDeviation = 0.0

for (num in numArray) {
    sum += num
}

val mean = sum / numArray.size

for (num in numArray) {
    standardDeviation += Math.pow(num - mean, 2.0)
}

val divider = numArray.size - 1

return Math.sqrt(standardDeviation / divider)
}

但是当我尝试这样做时,我得到Android Studio IDE错误,它会强调'+ ='运算符和' - '运算符。

我怎样才能让它发挥作用?

2 个答案:

答案 0 :(得分:3)

没有很好的方法可以按照您想象的方式对不同的数字类型进行泛化; IntDouble等仅延伸NumberComparable,两者均未定义operator plus

但是,在您的特定用例中,您可以利用累加器变量始终为Double的事实:

fun <T : Number> calculateSD(numArray: List<T>): Double {
    //   ^^^^^^

    // ... code code code ...

    for (num in numArray) {
        sum += num.toDouble()  // This *is* available via Number interface
    }

    // ... code code code ...
}

FWIW,你可以摆脱明确的循环:

val sum = numArray.sumByDouble { it.toDouble() }

或者因此:

val sum = numArray
    .map { it.toDouble() }
    .sum()

答案 1 :(得分:0)

正如@Oliver Charlesworth 已经回答的那样,你想要的是不可能的。但是,您可以更简洁地编写代码

private fun IntArray.std(): Double {
    val std = this.fold(0.0) { a, b -> a + (b-this.average()).pow(2) }
    return Math.sqrt(std / 10)
}

private fun FloatArray.std(): Double {
    val std = this.fold(0.0) { a, b -> a + (b-this.average()).pow(2) }
    return Math.sqrt(std / 10)
}

private fun DoubleArray.std(): Double {
    val std = this.fold(0.0) { a, b -> a + (b-this.average()).pow(2) }
    return Math.sqrt(std / 10)
}

您可以通过调用使用它:

val numArray1 = intArrayOf(6, 1, 1, 0,  1,  9,  0,  9,  6,  6)
val SD1 = numArray1.std()
val numArray2 = floatArrayOf(0.6f, 1.0f, 1.0f,  0.8f,   1.0f,   0.9f,   0.0f,   0.9f,   0.6f,   0.6f)
val SD2 = numArray2.std()
val numArray3 = doubleArrayOf(0.6, 1.0, 1.0,    0.8,    1.0,    0.9,    0.0,    0.9,    0.6,    0.6)
val SD3 = numArray3.std()

println("Standard Deviation = $SD1")
println("Standard Deviation = $SD2")
println("Standard Deviation = $SD3")