Halide:从发生器内运行BLAS操作

时间:2017-11-16 12:49:37

标签: blas halide

我正在尝试从Generator内执行矩阵 - 矩阵乘法。我知道我应该像以前一样使用define_extern和其他函数,但出于某种原因使用GEMM(即BLAS' mat-mat multiply),我得到了段错误。

这是我的代码:

class TestBLAS : public Halide::Generator<TestBLAS> {                                           
public:                                                                                                                 
    Input<Buffer<float>> A{"A", 2};                                                             
    Input<Buffer<float>> B{"B", 2};                                                             
    Output<Buffer<float>> C{"C", 2};                                                            

    Var x,y;                                                                                    

    void generate() {                                                                           
        Func g;                                                                                 
        g.define_extern("hblas_sgemm", {false, false, 1.f, A, B, 0.f}, type_of<float>(), 2);

        C(x,y) = g(x,y);                                                                        
    }                                                                                           
};                                                                                              

HALIDE_REGISTER_GENERATOR(TestBLAS, testblas)                                                   

在某些apps/linear_algebra脚本中,我发现Halide::Runtime::Buffer::raw_buffer()已被传递。如何从Halide::GeneratorInput<Halide::Buffer<float> >甚至Halide::Func

访问此指针

我理解Func的未知界限会让它变得困难,但是可能有单独提供这些信息的方式吗?

docs我似乎找不到方法......

更新:

按照@ Fabian的回复,我还没有设法从生成器中初始化context,但是我最接近使代码工作的是:

...
GEMMGenerator<float> gemm;
void generate() { 
    gemm.set_inputs(1., A, B, 0., B);    
    C = gemm.result_;
}

编译,但是当我运行它时,我得到:

Condition failed: funcs_.size() == array_size() && exprs_.empty()
Aborted (core dumped)

有什么方法吗?

1 个答案:

答案 0 :(得分:0)

查看您的代码我假设您想在generate()中使用另一个生成器。 然后这应该工作:

auto gen = context.create<GEMMGenerator<float>>();
gen->transpose_A_.set(false); // GeneratorParam
gen->apply(a, A, B, ...); // Inputs
Func res = gen->result_;