如何使以下卤化物代码更有效?

时间:2019-08-01 16:17:33

标签: c++ halide

下面的代码段运行速度比预期的慢。本文的作者http://www.cvlibs.net/publications/Geiger2010ACCV.pdf在118毫秒内计算900x700图像的support_points。我在下面的Halide中实现了他们的算法。 在我的算法中,在长度和宽度上嵌套的for循环在xi和yi上迭代,这是output_x和output_y(先前定义,但未在下面显示)中的点。在嵌套for循环的每次迭代中,都会计算向量top_k并将其推回support_points。 即使为left_buffer.width()== 20和left_buffer.height()== 20计算此管道也需要500毫秒。因此,此实现要慢几个数量级:


...
    int k = 4; // # of support points
    vector<pair<Expr, Expr>> support_points(k * left_buffer.width() * left_buffer.height());
    // Calculate support pixel for each
    Func support("support");
    support(x, y) = Tuple(i32(0), i32(0), f32(0));


    for (int yi = 0; yi < left_buffer.height(); yi++) {
        for (int xi = 0; xi < left_buffer.width() - 2; xi++) {
            bool left = xi < left_buffer.width() / 4;
            bool center = (xi >= left_buffer.width() / 4 && xi < left_buffer.width() * 3 / 4);
            bool right = xi >= left_buffer.width() * 3 / 4;

            vector <pair<Expr, Expr>> scan_range;
            pair <Expr, Expr> scan_height(0, (Expr) left_buffer.height());
            pair <Expr, Expr> scan_width;
            int which_pred = 0;
            if (left) {

                    scan_width = make_pair((Expr) 0, (Expr) left_buffer.width() / 2);
                    which_pred = 0;
            }
            else if (center) {
                    scan_width = make_pair((Expr) xi - left_buffer.width() / 4, (Expr) left_buffer.width() / 2);
                    which_pred = 1;
            }
            else if (right) {
                    scan_width = make_pair((Expr) left_buffer.width() / 2, (Expr) left_buffer.width() / 2);
                    which_pred = 2;
            }
            else {
                cout<<"Error"<<endl;
            }

            scan_range = {scan_width, scan_height};
//            cout<<"xi "<<xi<<endl;
//            cout<<"yi "<<yi<<endl;
//            cout<<"scan_width= "<<scan_width.first<<" "<<scan_width.second<<endl;
//            cout<<"scan_height= "<<scan_height.first<<" "<<scan_height.second<<endl;


            RDom scanner(scan_range);
            Expr predicate[3] = {scanner.x != xi && scanner.y != yi, scanner.x != 0 && scanner.y != 0, scanner.x != xi && scanner.y != yi};
            scanner.where(predicate[which_pred]);
            std::vector<Expr> top_k(k * 3);
            for (int i = 0; i < k; i++) { // say we want top 4 support points.
                top_k[3*i] = 10000.0f;
                top_k[3*i+1] = 0;
                top_k[3*i+2] = 0;
            }

            Func argmin("argmin");
            argmin() = Tuple(top_k);
            Expr next_val = abs(output_x(xi, yi) - output_x(scanner.x, scanner.y)) + abs(output_y(xi, yi) - output_y(scanner.x, scanner.y));
            Expr next_x = scanner.x;
            Expr next_y = scanner.y;

            top_k = Tuple(argmin()).as_vector();
            // Insert a single element into a sorted list without actually branching
            top_k.push_back(next_val);
            top_k.push_back(next_x);
            top_k.push_back(next_y);
            for (int i = k; i > 0; i--) {
                Expr prev_val = top_k[(i-1)*3];
                Expr prev_x = top_k[(i-1)*3 + 1];
                Expr prev_y = top_k[(i-1)*3 + 2];
                Expr should_swap = top_k[i*3] < prev_val;

                top_k[(i-1)*3] = select(should_swap, top_k[i*3], prev_val);
                top_k[(i-1)*3 + 1] = select(should_swap, top_k[i*3 + 1], prev_x);
                top_k[(i-1)*3 + 2] = select(should_swap, top_k[i*3 + 2], prev_y);
                top_k[i*3] = select(should_swap, prev_val, top_k[i*3]);
                top_k[i*3 + 1] = select(should_swap, prev_x, top_k[i*3 + 1]);
                top_k[i*3 + 2] = select(should_swap, prev_y, top_k[i*3 + 2]);
            }
            // Discard the k+1th element
            top_k.pop_back(); top_k.pop_back(); top_k.pop_back();

            bool cond = xi == 10 && yi == 10;
            cout << xi << " "<< yi << " " << cond << endl;

            Expr e = argmin()[0];

            e = print_when(cond, e, "<- argmin() val");
            argmin() = Tuple(top_k);
            argmin.compute_root();
//            argmin.trace_stores();


            argmin.compile_to_lowered_stmt("argmin.html", {}, HTML);
            Realization real = argmin.realize();
            for (int i = 0; i < k; i++) {
                pair<Expr, Expr> c(top_k[3*i+1], top_k[3*i+2]);
                support_points.push_back(c);
            }
        }
    }
    double t2 = current_time();

    cout<<(t2-t1)/100<<" ms"<<endl;
    cout<<"executed"<<endl;
}

如何提高效率?

1 个答案:

答案 0 :(得分:1)

您似乎在程序的各个阶段之间有些困惑。使用Halide,您与ExprsFuncs等一起使用的C ++代码实际上并没有进行任何评估,而是在构造Halide程序,然后可以对其进行编译和运行。这意味着您正在使用的C ++ for循环,std::vectors等都是在Halide程序的程序构建时(基本上是编译时)发生的。您可能会想到C ++模板,它们在编译时进行评估,而它们所构建的C ++代码,则在程序运行时进行评估:就此而言,您在此处编写的C ++代码等效于模板代码。您正在构建的卤化物程序。

这与JIT编译和评估构建该程序的同一C ++程序(realize)中的Halide程序的功能更加混淆。

实际上,我怀疑上面的程序实际上并没有计算您期望的结果。在两次for循环之后,您希望如何处理support_points?您构建的内容包含大量的表达式(代码段),而不是具体的值。而且,您正在JIT编译下,并且每次围绕这些循环(即,每个像素)运行一个新的Halide代码。

如果您现在坚持使用提前编译(compile_to_file或生成器),我认为您可能会更轻松地了解所构建的内容。这使两个阶段(卤化物代码生成时间以及该代码在单独程序中的运行时间)非常不同。