我一直在使用代码处理~9000个文件中的一些数据,因为每个进程都需要一些时间,我设法使用mpi4py来加速它。但我无法使程序捕获ctrl + c信号并保存已经计算的结果。这是代码:
import os
import sys
from string import atof
import gc
import pandas as pd
import numpy as np
from StructFunc.Structure_Function import SF_true, SF_fit_params
import scipy.io as sio
import mpi4py
from mpi4py import MPI
import time
import signal
# def handler(signal_num, frame):
# combine_list = comm.gather(p, root=0)
# if comm_rank == 0:
# print combine_list
# combine_dict = {}
# for sub_dict in combine_list:
# for cur_key in sub_dict.keys():
# combine_dict[cur_key] = sub_dict[cur_key]
# print combine_dict
# sio.savemat('./all_S82_DRW_params.mat', combine_dict)
# print "results before the interrupt has been saved."
# sys.exit()
# signal.signal(signal.SIGINT, handler)
comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
bad_value = [-1, -99.99, -1]
lc_path = '/Users/zhanghaowen/Desktop/AGN/BroadBand_RM/QSO_S82'
model="DRW"
if comm_rank == 0:
file_list = os.listdir(lc_path)
# print file_list
sys.stderr.write('%d files were to be processed.' %len(file_list))
file_list = comm.bcast(file_list if comm_rank == 0 else None, root=0)
num_files = len(file_list)
# num_files = 6
local_files_offset = np.linspace(0, num_files, comm_size+1).astype('int')
local_files = file_list[local_files_offset[comm_rank] : local_files_offset[comm_rank + 1]]
sys.stderr.write('%d/%d processor gets %d/%d data' %(comm_rank, comm_size, len(local_files), num_files))
cnt = 0
p = {}
for ind, lc in enumerate(local_files):
# beginning of process
try:
print local_files[ind]
lc_file = open(os.path.join(lc_path, local_files[ind]), 'r')
lc_data = pd.read_csv(lc_file, header=None, names=['MJD', 'mag', 'err'], usecols=[0, 1, 2], sep=' ')
lc_file.close()
#remove the bad values in the data
lc_data['MJD'] = lc_data['MJD'].replace(to_replace=bad_value[0], value=np.nan)
lc_data['MJD'] = lc_data['MJD'].replace(to_replace=np.nan, value=np.nanmean(lc_data['MJD']))
lc_data['mag'] = lc_data['mag'].replace(to_replace=bad_value[1], value=np.nan)
lc_data['mag'] = lc_data['mag'].replace(to_replace=np.nan, value=np.nanmean(lc_data['mag']))
lc_data['err'] = lc_data['err'].replace(to_replace=bad_value[2], value=np.nan)
lc_data['err'] = lc_data['err'].replace(to_replace=np.nan, value=np.nanmean(lc_data['err']))
MJD = np.array(lc_data['MJD'])
mag = np.array(lc_data['mag'])
err = np.array(lc_data['err'])
SF_params = []
resamp_tag = 0
while resamp_tag < 1:
sim_err = np.array([abs(np.random.normal(0, err[i], size=1)) for i in range(len(err))]).reshape((1, len(err)))[0]
try:
p[lc] = SF_fit_params(MJD, mag, sim_err, MCMC_step=100, MCMC_threads=1, model=model)
cnt += 1
sys.stderr.write('processor %d has processed %d/%d files \n' %(comm_rank, cnt, len(local_files)))
print "finished the MCMC for %s \n" %lc
resamp_tag += 1
except:
continue
# end of process
except KeyboardInterrupt:
combine_list = comm.gather(p, root=0)
if comm_rank == 0:
print combine_list
combine_dict = {}
for sub_dict in combine_list:
for cur_key in sub_dict.keys():
combine_dict[cur_key] = sub_dict[cur_key]
print combine_dict
sio.savemat('./all_S82_DRW_params.mat', combine_dict)
print "save the dict."
os._exit()
combine_list = comm.gather(p, root=0)
if comm_rank == 0:
print combine_list
combine_dict = {}
for sub_dict in combine_list:
for cur_key in sub_dict.keys():
combine_dict[cur_key] = sub_dict[cur_key]
print combine_dict
sio.savemat('./all_S82_DRW_params.mat', combine_dict)
我已经尝试了两种方法来捕获ctrl + c信号,即我定义的处理程序函数,但是注释了除了KeyboardInterrupt技巧。当我使用python all.py
运行代码然后键入ctrl + c时,脚本将捕获信号但在保存结果后继续运行;当我使用mpirun -np 2 all.py
并输入ctrl + c时,脚本会在不保存的情况下停止。
我认为MPI模式下的问题可能是只有管理器进程捕获了信号,但是为什么在输入ctrl + c后单个进程在非MPI模式下没有停止?有谁知道如何使工作流程捕获信号并在终止之前做一些事情?
顺便说一句,我使用的是openmpi。