使用Julia 1.0 findmax等效于numpy.argmax

时间:2019-04-09 22:49:42

标签: julia

在Julia中,我想找到矩阵的列索引以获取每一行的最大值,结果为// // Created by user on 6/4/19. // #ifndef CPP_DRAW_HPP #define CPP_DRAW_HPP /** * TODO : Having to switch off other rendering so that only one can go on the display. Find a workaround. * */ #include <GL/glut.h> #include <GLFW/glfw3.h> #include <map> #include <cmath> #include <zconf.h> template<typename T> T *curr_obj_to_draw; template<typename T> class Draw { public: Draw() { this->str_to_fp_mapping["rectangle"] = &(draw_rectangle); this->str_to_fp_mapping["line_loop"] = &(draw_line_loop); this->str_to_fp_mapping["square"] = &(draw_square); this->str_to_fp_mapping["circle"] = &(draw_circle); } static void display() { glClear(GL_COLOR_BUFFER_BIT); glMatrixMode(GL_PROJECTION); glFlush(); } void draw(int argc, char **argv, T *obj) { curr_obj_to_draw<T> = obj; glutInit(&argc, argv); glutInitDisplayMode(GLUT_SINGLE | GLUT_RGB); glutInitWindowSize(1000, 1000); glutInitWindowPosition(200, 20); glutCreateWindow("Drawing shapes"); this->init(); display(); //Check for existence of the name in the map else it can cause errors. glutKeyboardFunc(keyboardCB); std::string str = curr_obj_to_draw<T>->get_name(); glutDisplayFunc(str_to_fp_mapping[curr_obj_to_draw<T>->get_name()]); glutMainLoop(); } void init() { glClearColor(0.0, 0.0, 0.0, 0.0); glLoadIdentity(); gluOrtho2D(-500, 500, -500, 500); } static void draw_rectangle() { glColor3ub(0, 255, 0); glRecti(10, 10, curr_obj_to_draw<T>->get_length() + 100, curr_obj_to_draw<T>->get_breadth() + 100); glFlush(); } static void draw_line_loop() { glBegin(GL_LINE_LOOP); // glVertex2f(10,10); // glVertex2f(curr_obj_to_draw<T>->get_length() + 100, 10); // glVertex2f(curr_obj_to_draw<T>->get_length() + 100, curr_obj_to_draw<T>->get_breadth() + 100); // glVertex2f(10,curr_obj_to_draw<T>->get_breadth() + 100); glEnd(); } static void draw_square() { glColor3ub(0, 255, 255); //glRecti(10, 10, curr_obj_to_draw<T>->get_side() - 100, curr_obj_to_draw<T>->get_side() - 100); glFlush(); } static void draw_triangle() { glColor3ub(0, 255, 0); } static void draw_line() { glColor3ub(0, 255, 0); } static void draw_point() { glColor3ub(0, 255, 0); } static void setPixel(float xc, float yc, float x, float y) { glLineWidth(5); glBegin(GL_LINE_LOOP); glVertex2f(xc, yc); glVertex2f(x, y); glEnd(); } static void draw_circle() { glColor3f(1.0f, .3f, 0); // for (int i = 0; i < 360; i++) { // setPixel(-100, -100, -100 + (curr_obj_to_draw<T>->get_radius() * cos((i * 22.0/7.0) / 180)), // -100 + (curr_obj_to_draw<T>->get_radius() * sin((i * 22.0/7.0) / 180))); // } glFlush(); } static void keyboardCB( unsigned char key, int x, int y ) { //Press escape to close the current window switch ( key ) { case 27: // Escape key glutDestroyWindow ( glutGetWindow() ); exit (0); break; default: std::cout<<"x : " <<x << " y : "<<y<<"\n"; } glutPostRedisplay(); } private: std::map<std::string, void (*)()> str_to_fp_mapping; }; #endif //CPP_DRAW_HPP 。这是我目前的操作方式(Vector{Int}有7列和10,000行):

Samples

这有效,但感觉很笨拙且冗长。想知道是否有更好的方法。

3 个答案:

答案 0 :(得分:2)

更简单:Julia具有argmax函数,而Julia 1.1+具有eachrow迭代器。因此:

map(argmax, eachrow(x))

简单,易读且快速-在我的快速测试中,它与Colin的f3f4的性能相匹配。

答案 1 :(得分:1)

更新:为了完整起见,我向测试套件添加了Matt B.的出色解决方案(并且我还强迫{{1}中的transpose }生成一个新的矩阵,而不是一个惰性视图。

以下是一些不同的方法(您是基本情况f4):

f0

使用f0(x) = [ i[2] for i in findmax(x, dims = 2)[2]][:,1] f1(x) = getindex.(argmax(x, dims=2), 2) f2(x) = [ argmax(vec(x[n,:])) for n = 1:size(x,1) ] f3(x) = [ argmax(vec(view(x, n, :))) for n = 1:size(x,1) ] f4(x) = begin ; xt = Matrix{Float64}(transpose(x)) ; [ argmax(view(xt, :, k)) for k = 1:size(xt,2) ] ; end f5(x) = map(argmax, eachrow(x)) ,我们可以检查每一个的效率(我已经设定好BenchmarkTools):

x = rand(100, 200)

因此,马特(Matt)的方法显然是赢家,因为它似乎只是我的julia> @btime f0($x); 76.846 μs (13 allocations: 4.64 KiB) julia> @btime f1($x); 76.594 μs (11 allocations: 3.75 KiB) julia> @btime f2($x); 53.433 μs (103 allocations: 177.48 KiB) julia> @btime f3($x); 43.477 μs (3 allocations: 944 bytes) julia> @btime f4($x); 73.435 μs (6 allocations: 157.27 KiB) julia> @btime f5($x); 43.900 μs (4 allocations: 960 bytes) 的语法上更简洁的版本(两者可能编译为非常相似的东西,但是我认为检查一下会为时过高)

我希望f3可能有优势,尽管通过实例化f4创建了临时性的,因为它可以在矩阵的列而不是行上运行(Julia是主要列语言,因此对列的操作将总是更快,因为元素在内存中是同步的。但这似乎不足以克服临时程序的缺点。

请注意,如果总是需要完整的transpose,即每行中最大值的行和列索引,那么显然合适的解决方案就是CartesianIndex

答案 2 :(得分:0)

Mapslices函数也是解决此问题的好方法:

julia> Samples = rand(10000, 7);

julia> res = mapslices(row -> findmax(row)[2], Samples, dims=[2])[:,1];

julia> res[1:10]
10-element Array{Int64,1}:
 3
 1
 3
 5
 4
 4
 1
 4
 5
 3

尽管这比Colin上面建议的要慢得多,但对于某些人来说可能更易读。这基本上与您刚开始使用的代码完全相同,但是使用mapslices而不是列表推导。