运行mpi4py

时间:2017-06-02 01:39:07

标签: python openmpi mpi4py

我一直在使用代码处理~9000个文件中的一些数据,因为每个进程都需要一些时间,我设法使用mpi4py来加速它。但我无法使程序捕获ctrl + c信号并保存已经计算的结果。这是代码:

import os
import sys
from string import atof
import gc
import pandas as pd
import numpy as np
from StructFunc.Structure_Function import SF_true, SF_fit_params
import scipy.io as sio
import mpi4py
from mpi4py import MPI
import time
import signal



# def handler(signal_num, frame):

#   combine_list = comm.gather(p, root=0)

#   if comm_rank == 0:
#       print combine_list
#       combine_dict = {}
#       for sub_dict in combine_list:
#           for cur_key in sub_dict.keys():
#               combine_dict[cur_key] = sub_dict[cur_key]
#   print combine_dict
#   sio.savemat('./all_S82_DRW_params.mat', combine_dict)
#   print "results before the interrupt has been saved."
#   sys.exit()


# signal.signal(signal.SIGINT, handler)


comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()


bad_value = [-1, -99.99, -1]

lc_path = '/Users/zhanghaowen/Desktop/AGN/BroadBand_RM/QSO_S82'

model="DRW"



if comm_rank == 0:

    file_list = os.listdir(lc_path)
    # print file_list
    sys.stderr.write('%d files were to be processed.' %len(file_list))


file_list = comm.bcast(file_list if comm_rank == 0 else None, root=0)
num_files = len(file_list)
# num_files = 6
local_files_offset = np.linspace(0, num_files, comm_size+1).astype('int')
local_files = file_list[local_files_offset[comm_rank] : local_files_offset[comm_rank + 1]]
sys.stderr.write('%d/%d processor gets %d/%d data' %(comm_rank, comm_size, len(local_files), num_files))
cnt = 0

p = {}

for ind, lc in enumerate(local_files):
    # beginning of process
    try:
        print local_files[ind]
        lc_file = open(os.path.join(lc_path, local_files[ind]), 'r')
        lc_data = pd.read_csv(lc_file, header=None, names=['MJD', 'mag', 'err'], usecols=[0, 1, 2], sep=' ')
        lc_file.close()

        #remove the bad values in the data

        lc_data['MJD'] = lc_data['MJD'].replace(to_replace=bad_value[0], value=np.nan)
        lc_data['MJD'] = lc_data['MJD'].replace(to_replace=np.nan, value=np.nanmean(lc_data['MJD']))

        lc_data['mag'] = lc_data['mag'].replace(to_replace=bad_value[1], value=np.nan)
        lc_data['mag'] = lc_data['mag'].replace(to_replace=np.nan, value=np.nanmean(lc_data['mag']))

        lc_data['err'] = lc_data['err'].replace(to_replace=bad_value[2], value=np.nan)
        lc_data['err'] = lc_data['err'].replace(to_replace=np.nan, value=np.nanmean(lc_data['err']))

        MJD = np.array(lc_data['MJD'])
        mag = np.array(lc_data['mag'])
        err = np.array(lc_data['err'])

        SF_params = []

        resamp_tag = 0
        while resamp_tag < 1:
            sim_err = np.array([abs(np.random.normal(0, err[i], size=1)) for i in range(len(err))]).reshape((1, len(err)))[0]

            try:
                p[lc] = SF_fit_params(MJD, mag, sim_err, MCMC_step=100, MCMC_threads=1, model=model)
                cnt += 1
                sys.stderr.write('processor %d has processed %d/%d files \n' %(comm_rank, cnt, len(local_files)))
                print "finished the MCMC for %s \n" %lc
                resamp_tag += 1
            except:
                continue
    # end of process


    except KeyboardInterrupt:
        combine_list = comm.gather(p, root=0)

        if comm_rank == 0:
            print combine_list
            combine_dict = {}
            for sub_dict in combine_list:
                for cur_key in sub_dict.keys():
                    combine_dict[cur_key] = sub_dict[cur_key]
            print combine_dict

        sio.savemat('./all_S82_DRW_params.mat', combine_dict)
        print "save the dict."
        os._exit()



combine_list = comm.gather(p, root=0)

if comm_rank == 0:
    print combine_list
    combine_dict = {}
    for sub_dict in combine_list:
        for cur_key in sub_dict.keys():
            combine_dict[cur_key] = sub_dict[cur_key]
    print combine_dict

    sio.savemat('./all_S82_DRW_params.mat', combine_dict)

我已经尝试了两种方法来捕获ctrl + c信号,即我定义的处理程序函数,但是注释了除了KeyboardInterrupt技巧。当我使用python all.py运行代码然后键入ctrl + c时,脚本将捕获信号但在保存结果后继续运行;当我使用mpirun -np 2 all.py并输入ctrl + c时,脚本会在不保存的情况下停止。

我认为MPI模式下的问题可能是只有管理器进程捕获了信号,但是为什么在输入ctrl + c后单个进程在非MPI模式下没有停止?有谁知道如何使工作流程捕获信号并在终止之前做一些事情?

顺便说一句,我使用的是openmpi。

0 个答案:

没有答案