快速MLMultiArray的有效平方和距离

时间:2020-10-08 12:27:29

标签: swift coreml

尝试对MLMultiArray实现通用距离扩展,以便我可以清除自定义KNN分类器。

这是我为避免递归或循环中上下文切换而提出的实现。除了丑陋之外,它还应该做好自己的工作。


extension MLMultiArray {
    func distanceFrom(otherArray:MLMultiArray) throws -> Float  {
        var d:Float = 0.0
        let shape = self.shape as! [Int]
        switch shape.count {
        case 1:
            for i in 0..<shape[0] {
                d += pow(self[i].floatValue - otherArray[i].floatValue,2)
            }
        case 2:
            for i in 0..<shape[0] {
                for j in 0..<shape[1] {
                    let index = [i,j] as [NSNumber]
                    let di:Float = (self[index].floatValue - otherArray[index].floatValue)
                    d += pow(di,2)
                }
            }
        case 3:
            for i in 0..<shape[0] {
                for j in 0..<shape[1] {
                    for k in 0..<shape[2] {
                        let index = [i,j,k] as [NSNumber]
                        let di:Float = (self[index].floatValue - otherArray[index].floatValue)
                        d += pow(di,2)
                    }
                }
            }
        default:
            throw fatalError("not implemented")
        }
        d = pow(d, 0.5)
        return d
    }
}

有更聪明的方法吗?我习惯了numpy的多路复用,从来没有做过任何快速的数值运算,所以我可能忽略了显而易见的内容。

0 个答案:

没有答案