我正在尝试从Generator
内执行矩阵 - 矩阵乘法。我知道我应该像以前一样使用define_extern
和其他函数,但出于某种原因使用GEMM(即BLAS' mat-mat multiply),我得到了段错误。
这是我的代码:
class TestBLAS : public Halide::Generator<TestBLAS> {
public:
Input<Buffer<float>> A{"A", 2};
Input<Buffer<float>> B{"B", 2};
Output<Buffer<float>> C{"C", 2};
Var x,y;
void generate() {
Func g;
g.define_extern("hblas_sgemm", {false, false, 1.f, A, B, 0.f}, type_of<float>(), 2);
C(x,y) = g(x,y);
}
};
HALIDE_REGISTER_GENERATOR(TestBLAS, testblas)
在某些apps/linear_algebra
脚本中,我发现Halide::Runtime::Buffer::raw_buffer()
已被传递。如何从Halide::GeneratorInput<Halide::Buffer<float> >
甚至Halide::Func
?
我理解Func
的未知界限会让它变得困难,但是可能有单独提供这些信息的方式吗?
从docs我似乎找不到方法......
更新:
按照@ Fabian的回复,我还没有设法从生成器中初始化context
,但是我最接近使代码工作的是:
...
GEMMGenerator<float> gemm;
void generate() {
gemm.set_inputs(1., A, B, 0., B);
C = gemm.result_;
}
编译,但是当我运行它时,我得到:
Condition failed: funcs_.size() == array_size() && exprs_.empty()
Aborted (core dumped)
有什么方法吗?
答案 0 :(得分:0)
查看您的代码我假设您想在generate()
中使用另一个生成器。
然后这应该工作:
auto gen = context.create<GEMMGenerator<float>>();
gen->transpose_A_.set(false); // GeneratorParam
gen->apply(a, A, B, ...); // Inputs
Func res = gen->result_;