我有BLAS功能的特性:
pub trait Blas {
fn gemv<F>(&self, trans: Transpose,
cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
vector: &[F], vector_inc: usize, vector_factor: F,
result: &[F], result_inc: usize) -> Result<(), Error>;
...
}
现在我想创建一个实现这个特性的类型:
pub struct CudaBlas {
...
}
impl Blas for CudaBlas {
...
}
问题是我需要gemv<f32>
和gemv<f64>
的单独专业化:每个专业都应该调用专用的共享库函数。没有编译器投诉,没有成功表达。我怎样才能做到这一点?
更新
我尝试了method proposed by Jonas Tepe并且它似乎无法正常工作。这是纯化的例子:
trait Trait<T> {
fn func(&self, arg: T);
}
struct Struct {
field: usize,
}
impl Trait<f32> for Struct {
fn func(&self, arg: f32) {
println!("32bits: {}", arg);
}
}
impl Trait<f64> for Struct {
fn func(&self, arg: f64) {
println!("64bits: {}", arg);
}
}
struct Struct2<T> {
field2: T,
}
// yes, I plan to use my CudaBlas inside some generic NeuralNet<T>
impl<T> Struct2<T> {
fn func2(&self, arg: T) {
let s = Struct{field: 1};
s.func(arg);
}
}
fn main() {
let s32 = Struct2::<f32>{field2: 1f32};
let s64 = Struct2::<f64>{field2: 2f64};
s32.func2(1f32);
s64.func2(1f64);
}
我明白了:
错误:未对类型
实施特性Trait<T>
[E0277]Struct
使Struct
成为通用的并不能解决问题(编译器抱怨找不到类型func
的{{1}})。只是惊讶于Rust仿制药的限制性。
答案 0 :(得分:3)
一种解决方案是使您的特征Blas
相对于浮点类型具有通用性,然后为CudaBlas struct
提供此特征的两个单独实现:
pub trait Blas<F> {
fn gemv(&self, trans: Transpose,
cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
vector: &[F], vector_inc: usize, vector_factor: F,
result: &[F], result_inc: usize) -> Result<(), Error>;
...
}
impl Blas<f32> for CudaBlas {
fn gemv(&self, trans: Transpose,
cols: usize, rows: usize, matrix: &[f32], matrix_factor: f32,
vector: &[f32], vector_inc: usize, vector_factor: f32,
result: &[f32], result_inc: usize) -> Result<(), Error> {
// implement f32 specific functionality
}
}
impl Blas<f64> for CudaBlas {
fn gemv(&self, trans: Transpose,
cols: usize, rows: usize, matrix: &[f64], matrix_factor: f64,
vector: &[f64], vector_inc: usize, vector_factor: f64,
result: &[f64], result_inc: usize) -> Result<(), Error> {
// implement f64 specific functionality
}
}
之后,您可以使用gemv()
或CudaBlas
每次使用所需的特定类型结果调用f32
上的方法f64
。