我如何专注于特质功能?

时间:2015-11-22 12:49:41

标签: rust

我有BLAS功能的特性:

pub trait Blas {
    fn gemv<F>(&self, trans: Transpose,
               cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
               vector: &[F], vector_inc: usize, vector_factor: F,
               result: &[F], result_inc: usize) -> Result<(), Error>;
    ...
}

现在我想创建一个实现这个特性的类型:

pub struct CudaBlas {
    ...
}

impl Blas for CudaBlas {
    ...
}

问题是我需要gemv<f32>gemv<f64>的单独专业化:每个专业都应该调用专用的共享库函数。没有编译器投诉,没有成功表达。我怎样才能做到这一点?

更新

我尝试了method proposed by Jonas Tepe并且它似乎无法正常工作。这是纯化的例子:

trait Trait<T> {
    fn func(&self, arg: T);
}

struct Struct {
    field: usize,
}

impl Trait<f32> for Struct {
    fn func(&self, arg: f32) {
        println!("32bits: {}", arg);    
    }
}

impl Trait<f64> for Struct {
    fn func(&self, arg: f64) {
        println!("64bits: {}", arg);
    }
}

struct Struct2<T> {
    field2: T,
}

// yes, I plan to use my CudaBlas inside some generic NeuralNet<T>
impl<T> Struct2<T> {
    fn func2(&self, arg: T) {
        let s = Struct{field: 1};
        s.func(arg);
    }
}

fn main() {
    let s32 = Struct2::<f32>{field2: 1f32};
    let s64 = Struct2::<f64>{field2: 2f64};
    s32.func2(1f32);
    s64.func2(1f64);
}

我明白了:

  

错误:未对类型Trait<T> [E0277]

实施特性Struct

使Struct成为通用的并不能解决问题(编译器抱怨找不到类型func的{​​{1}})。只是惊讶于Rust仿制药的限制性。

1 个答案:

答案 0 :(得分:3)

一种解决方案是使您的特征Blas相对于浮点类型具有通用性,然后为CudaBlas struct提供此特征的两个单独实现:

pub trait Blas<F> {
    fn gemv(&self, trans: Transpose,
               cols: usize, rows: usize, matrix: &[F], matrix_factor: F,
               vector: &[F], vector_inc: usize, vector_factor: F,
               result: &[F], result_inc: usize) -> Result<(), Error>;
    ...
}

impl Blas<f32> for CudaBlas {
    fn gemv(&self, trans: Transpose,
            cols: usize, rows: usize, matrix: &[f32], matrix_factor: f32,
            vector: &[f32], vector_inc: usize, vector_factor: f32,
            result: &[f32], result_inc: usize) -> Result<(), Error> {
           // implement f32 specific functionality
     }
}

impl Blas<f64> for CudaBlas {
        fn gemv(&self, trans: Transpose,
                cols: usize, rows: usize, matrix: &[f64], matrix_factor: f64,
                vector: &[f64], vector_inc: usize, vector_factor: f64,
                result: &[f64], result_inc: usize) -> Result<(), Error> {
               // implement f64 specific functionality
         }

}

之后,您可以使用gemv()CudaBlas每次使用所需的特定类型结果调用f32上的方法f64