下面的代码会导致编译错误。
如何使用swift泛型执行类似的操作?还是不可能?
func
S( _ p: UnsafePointer< Float > ) {
}
func
S( _ p: UnsafePointer< Double > ) {
}
func
Sum< N >( _ p: UnsafePointer< N > ) {
S( p )
}
编译错误:
Cannot invoke 'S' with an argument list of type '(UnsafePointer<N>)'
答案 0 :(得分:4)
(请确保也看到Rob Mayoff的答案。虽然这个答案更接近于原始问题并解释了一些重要的概念,但是如果我自己构建,我可能会使用更接近他的东西。他也是记忆力安全,此答案不是,因为我是从原始问题中复制过来的。)
您想从另一个方向执行此操作。基于您的原始代码(该代码不安全地使用baseAddress
,但现在将其保留):
import Accelerate
func sum( _ p: ArraySlice<Float>) -> Float {
return sum(p, summer: vDSP_sve)
}
func sum( _ p: ArraySlice<Double>) -> Double {
return sum(p, summer: vDSP_sveD)
}
typealias Summer<N: Numeric> = (UnsafePointer<N>, vDSP_Stride, UnsafeMutablePointer<N>, vDSP_Length) -> Void
func sum<N: Numeric>(_ p: ArraySlice<N>, summer: Summer<N>) -> N {
var v: N = 0
summer( p.withUnsafeBufferPointer { $0.baseAddress! }
, vDSP_Stride( 1 )
, &v
, vDSP_Length( p.count )
)
return v
}
最不通用的函数(仅适用于Float或Double的那些函数)应调用最通用的函数(对Numeric起作用的那个函数),并传递要更改的部分(加速函数)。
请考虑一下您的代码中的情况,即我在Sum
(也是Int
)列表上调用Numeric
。在这种情况下,它将叫什么“ S”?这就是为什么您从最不普通的人转到最普通的人,而不是相反的原因。
答案 1 :(得分:2)
另一种方法是定义包装vDSP API的Numeric
子协议:
import Accelerate
protocol Acceleratable: Numeric {
static func acceleratedSum(ofElementsStartingAt basePointer: UnsafePointer<Self>, stride: vDSP_Stride, count: vDSP_Length) -> Self
}
extension Float: Acceleratable {
static func acceleratedSum(ofElementsStartingAt basePointer: UnsafePointer<Float>, stride: vDSP_Stride, count: vDSP_Length) -> Float {
var sum: Float = 0
vDSP_sve(basePointer, stride, &sum, count)
return sum
}
}
extension Double: Acceleratable {
static func acceleratedSum(ofElementsStartingAt basePointer: UnsafePointer<Double>, stride: vDSP_Stride, count: vDSP_Length) -> Double {
var sum: Double = 0
vDSP_sveD(basePointer, stride, &sum, count)
return sum
}
}
extension ArraySlice where Element: Acceleratable {
func acceleratedSum() -> Element {
return withUnsafeBufferPointer {
guard let base = $0.baseAddress else { return 0 }
return Element.acceleratedSum(ofElementsStartingAt: base, stride: 1, count: numericCast($0.count))
}
}
}
extension Array where Element: Acceleratable {
func acceleratedSum() -> Element {
return self[...].acceleratedSum()
}
}
let doubles: [Double] = [2, 3, 5]
print(doubles.acceleratedSum())
let floats: [Float] = [3, 1, 4]
print(floats.acceleratedSum())