我有一个问题,简化时:
以下是示例代码示例:
from numpy.random import uniform
from time import sleep
def userfunction(x):
# do something complicated
# but computation always takes takes roughly the same time
sleep(1) # comment this out if too slow
xnew = uniform() # in reality, a non-trivial function of x
y = -0.5 * xnew**2
return xnew, y
x0, cur = userfunction([])
x = [x0] # a sequence of points
while cur < -2e-16:
# this should be parallelised
# search for a new point higher than a threshold
x1, next = userfunction(x)
if next <= cur:
# throw away (this branch is taken 99% of the time)
pass
else:
cur = next
print cur
x.append(x1) # note that userfunction depends on x
print x
我希望并行化(例如跨群集),但问题是我需要在找到成功点后终止其他工作者,或者至少告知他们新的x(如果他们设法获得如果使用较旧的x,新阈值以上,结果仍然可以接受)。只要没有成功,我就需要工人重复。
我正在寻找可以处理这类问题的工具/框架,用任何科学编程语言(C,C ++,Python,Julia等,请不要Fortran)。
可以用优雅的MPI解决这个问题吗?我不明白如何使用MPI通知/中断/更新工作人员。
更新:添加代码注释,说大多数尝试都不成功,并且不影响变量userfunction取决于。
答案 0 :(得分:0)
如果userfunction()
不花太长时间,那么这里有一个符合“MPI半优雅”的选项
为了保持简单,让我们假设等级0只是一个协调器并且不计算任何东西。
排名0
cur = 0
x = []
while cur < -2e-16:
MPI_Recv(buf=cur+x1, src=MPI_ANY_SOURCE)
x.append(x1)
MPI_Ibcast(buf=cur+x, root=0, request=req)
MPI_Wait(request=req)
排名!= 0
x0, cur = userfunction([])
x = [x0] # a sequence of points
while cur < -2e-16:
MPI_Ibcast(buf=newcur+newx, root=0, request=req
# search for a new point higher than a threshold
x1, next = userfunction(x)
if next <= cur:
# throw away (this branch is taken 99% of the time)
MPI_Test(request=ret, flag=found)
if found:
MPI_Wait(request)
else:
cur = next
MPI_Send(buffer=cur+x1, dest=0)
MPI_Wait(request)
需要额外的逻辑才能正确处理 - 等级0也进行计算 - 几个等级同时找到解决方案,后续消息必须由等级0消耗
严格来说,当在其他任务上找到解决方案时,任务不会“中断”。相反,每个任务都会定期检查其他任务是否找到了解决方案。所以在某个地方发现解决方案和所有任务都停止寻找解决方案之间存在延迟,但如果userfunction()
没有“太长时间”,这对我来说是非常可以接受的。
答案 1 :(得分:0)
我用以下代码大致解决了这个问题。
此时仅传输curmax,但可以使用第二个广播+标签发送另一个阵列。
import numpy
import time
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
import logging
logging.basicConfig(filename='mpitest%d.log' % rank,level=logging.DEBUG)
logFormatter = logging.Formatter("[%(name)s %(levelname)s]: %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
consoleHandler.setLevel(logging.INFO)
logging.getLogger().addHandler(consoleHandler)
log = logging.getLogger(__name__)
if rank == 0:
curmax = numpy.random.random()
seq = [curmax]
log.info('%d broadcasting starting value %f...' % (rank, curmax))
comm.Ibcast(numpy.array([curmax]))
was_updated = False
while True:
# check if news available
status = MPI.Status()
a_avail = comm.iprobe(source=MPI.ANY_SOURCE, tag=12, status=status)
if a_avail:
sugg = comm.recv(source=status.Get_source(), tag=12)
log.info('%d received new limit from %d: %s' % (rank, status.Get_source(), sugg))
if sugg < curmax:
curmax = sugg
seq.append(curmax)
log.info('%d updating to %s' % (rank, curmax))
was_updated = True
else:
# ignore
pass
# check if next message is already waiting:
if comm.iprobe(source=MPI.ANY_SOURCE, tag=12):
# consume it first before broadcasting outdated info
continue
if was_updated:
log.info('%d broadcasting new limit %f...' % (rank, curmax))
comm.Ibcast(numpy.array([curmax]))
was_updated = False
else:
# no message waiting for us and no broadcast done, so pause
time.sleep(0.1)
print
print data, rank
else:
log.info('%d waiting for root to send us starting value...' % (rank))
nextmax = numpy.empty(1, dtype=float)
comm.Ibcast(nextmax).Wait()
amax = float(nextmax)
numpy.random.seed(rank)
update_req = comm.Ibcast(nextmax)
while True:
a = numpy.random.uniform()
if a < amax:
log.info('%d found new: %s, sending to root' % (rank, a))
amax = a
comm.isend(a, dest=0, tag=12)
s = update_req.Get_status()
#log.info('%d bcast status: %s' % (rank, s))
if s:
update_req.Wait()
log.info('%d receiving new limit from root, %s' % (rank, nextmax))
amax = float(nextmax)
update_req = comm.Ibcast(nextmax)