How can I improve MPI I/O speed on a large number of cores?

Asked: 2014-04-20 08:24:46

Tags: fortran90 hpc mpi-io

I have been trying to run a code that uses MPI I/O on a large number of cores. The time each core needs to read from and write to a single file (the same file for all cores) grows as the number of cores increases. I am currently running on 512 cores, and this problem makes my project infeasible. The problem already shows up on 8 cores, where it takes about 0.2 seconds to read the first real number from the file; on 32 cores, writing one real number takes more than 30 seconds. I am running on this system: https://www.msi.umn.edu/hpc/itasca. The simple code below reproduces the problem (counting the number of elements in the file may look unnecessary here, but it is needed in my actual code):

PROGRAM MAIN

USE MPI
IMPLICIT NONE

! INITIALIZING VARIABLES   

REAL(8) :: A, B
INTEGER :: COUNT_IO, i, j, ST, GO, tag, t, nb_bytes, N, d_each, d_start, d_end, NN
REAL(8) :: time_start, time_end

! VARIABLES RELATED TO MPI

INTEGER :: ierror  ! returns error messages from the mpi subroutines 
INTEGER :: rank    ! identification number of each processor
INTEGER :: nproc   ! number of processors
INTEGER, DIMENSION(mpi_status_size):: status
INTEGER(kind= MPI_OFFSET_KIND ) :: offset
INTEGER :: fh  ! file handle

! EXECUTABLE

    ! INITIALIZE THE MPI ENVIRONMENT

    CALL MPI_INIT(ierror)                           ! initialize MPI 
    CALL MPI_COMM_RANK(MPI_COMM_WORLD,rank,ierror)  ! obtain rank for each node
    CALL MPI_COMM_SIZE(MPI_COMM_WORLD,nproc,ierror) ! obtain the number of nodes
    CALL MPI_TYPE_SIZE(MPI_REAL8,nb_bytes,ierror)

    CALL MPI_FILE_OPEN (MPI_COMM_WORLD,"file.dat",MPI_MODE_RDWR+MPI_MODE_UNIQUE_OPEN,MPI_INFO_NULL,fh,ierror)    

    NN = 2048

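    ! Find the smallest per-rank chunk size d_each such that nproc chunks of
    ! that size cover all NN elements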
    DO d_each=1,NN
        IF (d_each*nproc>=NN) EXIT
    END DO
    d_start = rank*d_each+1 
    d_end   = MIN((rank+1)*d_each,NN)

    DO t = d_start,d_end

        ! READING ONE THREAD AT A TIME

        tag = 1

        GO = 0

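        ! Serialize the I/O: wait for a go-ahead token from the previous rank
        ! before touching the file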
        IF (rank .gt. 0) THEN
            CALL MPI_RECV (GO,1,MPI_INTEGER,rank-1,tag, MPI_COMM_WORLD ,status,ierror)
        ENDIF

        time_start = MPI_WTIME()

        i  = 0
        ST = 0
        COUNT_IO = 0

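        ! Probe the file one double at a time until a read returns no data,
        ! to count how many doubles the file currently holds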
        DO WHILE ((i .lt. 100000) .AND. (ST .eq. 0))
            i = i+1
            offset = nb_bytes*(i-1)
            CALL MPI_FILE_READ_AT (fh,offset,A,1,MPI_REAL8,status,ierror)
            IF (status(1) .eq. 0) THEN
                COUNT_IO = i
                ST = 1
            ELSE
                COUNT_IO = 0
            END IF        
        ENDDO

        N = (COUNT_IO - 1)

        IF (N .gt. 0) THEN

            offset = 0                      
            CALL MPI_FILE_READ_AT (fh,offset,B,1,MPI_REAL8,status,ierror)

        ENDIF

        time_end = MPI_WTIME()

        PRINT *, 'My rank is', rank, 'Time for read  =',time_end-time_start 

        GO = 1    
        IF (rank .lt. nproc-1) THEN
            CALL MPI_SEND (GO,1, MPI_INTEGER ,rank+1,tag, MPI_COMM_WORLD ,ierror)
        ENDIF

        CALL MPI_BARRIER(MPI_COMM_WORLD,ierror)

        ! WRITING ONE THREAD AT A TIME

        tag = 2

        GO = 0

        IF (rank .gt. 0) THEN
            CALL MPI_RECV (GO,1,MPI_INTEGER,rank-1,tag, MPI_COMM_WORLD ,status,ierror)
        ENDIF

        time_start = MPI_WTIME()

        i  = 0
        ST = 0
        COUNT_IO = 0

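        ! As before, probe double by double to find the current end of the file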
        DO WHILE ((i .lt. 100000) .AND. (ST .eq. 0))
            i = i+1
            offset = nb_bytes*(i-1)
            CALL MPI_FILE_READ_AT (fh,offset,A,1,MPI_REAL8,status,ierror)
            IF (status(1) .eq. 0) THEN
                COUNT_IO = i
                ST = 1
            ELSE
                COUNT_IO = 0
            END IF        
        ENDDO

        N = (COUNT_IO - 1)

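        ! Append a zero just past the last double currently in the file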
        offset = nb_bytes*N
        CALL MPI_FILE_WRITE_AT (fh,offset,0.0D0,1,MPI_REAL8,status,ierror) 

        time_end = MPI_WTIME()  

        PRINT *, 'My rank is', rank, 'Time for write =',time_end-time_start

        GO = 1    
        IF (rank .lt. nproc-1) THEN
            CALL MPI_SEND (GO,1, MPI_INTEGER ,rank+1,tag, MPI_COMM_WORLD ,ierror)
        ENDIF

        CALL MPI_BARRIER(MPI_COMM_WORLD,ierror)

    ENDDO

    CALL MPI_FILE_CLOSE (fh,ierror)

    CALL MPI_FINALIZE(ierror)

END PROGRAM MAIN

1 Answer:

Answer (score: 1):

The main thing to realize here is that you can read the data in one go (or, if memory is a concern, in chunks, but chunks much larger than individual doubles!), and that you don't need to skip to the end of the file one double at a time.
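For instance, instead of probing one double at a time to find where the file ends, you can ask MPI for the file size and read a whole block in a single call. Here is a minimal sketch of that idea (a hypothetical stand-alone program, assuming the file "file.dat" holds nothing but doubles and is small enough to read in one piece; the complete worked example follows below):

program read_whole
    use mpi
    implicit none

    integer :: fh, ierr, dblsize, ndoubles
    integer(MPI_OFFSET_KIND) :: filesize
    double precision, allocatable, dimension(:) :: buf

    call MPI_Init(ierr)

    call MPI_File_open(MPI_COMM_WORLD, "file.dat", MPI_MODE_RDONLY, &
                       MPI_INFO_NULL, fh, ierr)

    ! ask MPI for the file size instead of probing one double at a time
    call MPI_File_get_size(fh, filesize, ierr)
    call MPI_Type_size(MPI_DOUBLE_PRECISION, dblsize, ierr)
    ndoubles = int(filesize / dblsize)

    allocate(buf(ndoubles))

    ! every rank reads the whole file with one collective call
    call MPI_File_read_at_all(fh, 0_MPI_OFFSET_KIND, buf, ndoubles, &
                              MPI_DOUBLE_PRECISION, MPI_STATUS_IGNORE, ierr)

    call MPI_File_close(fh, ierr)
    call MPI_Finalize(ierr)
end program read_whole

Because MPI_File_read_at_all is collective, the MPI-IO layer can coordinate the requests from all ranks, which scales far better than serializing the ranks with send/receive tokens.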

Here is an example that reads the data in blocks of whatever size you like, processes it however you wish, and then appends some data (in this case, every rank just adds 4 copies of its rank number to the end of the file). For simplicity, a couple of small Python scripts help create and display the test data.

$ ./writedata.py 
$ ./readdata.py 
[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.
  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.]

$ mpirun -np 3 ./usepario
 rank:   0 got data: 0.000...   24.000 
 rank:   1 got data: 0.000...   24.000
 rank:   2 got data: 0.000...   24.000

$ ./readdata.py 
[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.
  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.   0.   0.   0.   0.   1.
   1.   1.   1.   2.   2.   2.   2.]

usepario.f90:

module pario

contains
    function openFile(filename)
        use mpi
        implicit none
        integer :: openFile, ierr
        character(len=*) :: filename
        integer(MPI_OFFSET_KIND) :: off = 0

        call MPI_File_open(MPI_COMM_WORLD, filename,  &
                           ior(MPI_MODE_RDWR, MPI_MODE_UNIQUE_OPEN),  &
                           MPI_INFO_NULL, openFile, ierr)
        call MPI_File_set_view(openFile, off,  &
                               MPI_DOUBLE_PRECISION, MPI_DOUBLE_PRECISION, &
                               "native", MPI_INFO_NULL, ierr)
    end function  openFile

    subroutine closeFile(fh)
        use mpi
        implicit none
        integer :: fh, ierr
        call MPI_File_close(fh, ierr)
    end subroutine closeFile

    function filesizedoubles(fh)
        use mpi
        implicit none
        integer :: fh, ierr
        integer(MPI_OFFSET_KIND) :: filesize, filesizedoubles
        integer :: dblsize

        call MPI_File_get_size(fh, filesize, ierr)
        call MPI_type_size(MPI_DOUBLE_PRECISION, dblsize, ierr)
        filesizedoubles = filesize / dblsize
    end function filesizedoubles

    subroutine getdatablock(fh, blocksize, datablock, datasize)
        use mpi
        implicit none
        integer :: fh, ierr
        integer :: blocksize, datasize
        double precision, dimension(:) :: datablock
        integer(MPI_OFFSET_KIND) :: fileloc
        integer, dimension(MPI_STATUS_SIZE) :: rstatus

        ! you could also experiment with MPI_File_read_all here for
        ! collective / synchronized file access

        call MPI_File_read(fh, datablock, blocksize, MPI_DOUBLE_PRECISION, &
                           rstatus, ierr)
        call MPI_Get_count(rstatus, MPI_DOUBLE_PRECISION, datasize, ierr)
    end subroutine getdatablock

    subroutine eachappend(fh, filesize, numitems, newdata)
        use mpi
        implicit none
        integer :: fh, numitems
        integer(MPI_OFFSET_KIND) :: filesize
        double precision, dimension(:) :: newdata
        integer :: rank, ierr
        integer(MPI_OFFSET_KIND) :: offset

        call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
        offset = filesize + rank*numitems
        call MPI_File_write_at_all(fh, offset, newdata, numitems, &
                                    MPI_DOUBLE_PRECISION,         &
                                    MPI_STATUS_IGNORE, ierr)

    end subroutine eachappend
end module pario


program usepario
    use mpi
    use pario
    implicit none

    integer :: fileh
    integer, parameter :: bufsize=1000, newsize=4
    integer(MPI_OFFSET_KIND) :: filesize
    double precision, allocatable, dimension(:) :: curdata, newdata
    integer :: datasize
    integer :: rank, ierr

    call MPI_Init(ierr)
    call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)

    allocate(curdata(bufsize))

    fileh = openFile("data.dat")
    filesize = filesizedoubles(fileh)

    do
        call getdatablock(fileh, bufsize, curdata, datasize)
        !! 
        !! process data here
        !!
        !! do i=1,datasize
        !!  ...dostuff...
        !! end do
        !! 
        print '(1X,A,I3,A,F8.3,A,F8.3)', 'rank: ', rank, ' got data: ', curdata(1), '...', curdata(datasize)
        if (datasize /= bufsize) exit
    end do

    deallocate(curdata)

    allocate(newdata(newsize))
    newdata = rank

    call eachappend(fileh, filesize, newsize, newdata)
    call closeFile(fileh)

    call MPI_Finalize(ierr)
end program usepario

writedata.py:

#!/usr/bin/env python

import numpy

numdoubles = 25

data = numpy.arange(numdoubles,dtype=numpy.float64)
data.tofile("data.dat")

readdata.py:

#!/usr/bin/env python

import numpy

data = numpy.fromfile("data.dat",dtype=numpy.float64)
print(data)