使用 multiprocessing 加速 DEAP 时报 OSError

时间:2019-07-31 08:56:35

标签: python multiprocessing deap

我想使用 multiprocessing 来加速 DEAP,但总是出现 OSError。下面是我的代码的精简版本:

import operator
import math
import random
import numpy as np
import pandas as pd    
from deap import algorithms
from deap import base
from deap import creator
from deap import tools
from deap import gp
import multiprocessing

# protectedDiv
def protectedDiv(left, right):
    """Divide ``left`` by ``right``, yielding 1 when the division by zero fails.

    Only ``ZeroDivisionError`` is intercepted; any other exception raised by
    the ``/`` operator propagates to the caller.

    NOTE(review): NumPy arrays do not raise ``ZeroDivisionError`` on division
    by zero (they emit a warning and produce inf/nan), so this guard would
    presumably never trigger for ``np.ndarray`` operands -- confirm intent.
    """
    try:
        quotient = left / right
    except ZeroDivisionError:
        quotient = 1
    return quotient

# omitting some other functions

# Fitness class with a single positive weight (the objective is maximised),
# and an Individual class: a GP primitive tree carrying that fitness.
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMax)

# Strongly typed GP primitive set: twelve np.ndarray inputs, one np.ndarray output.
pset = gp.PrimitiveSetTyped("MAIN", [np.ndarray] * 12, np.ndarray)
# Binary primitives taking two arrays and returning an array.
pset.addPrimitive(operator.add, [np.ndarray, np.ndarray], np.ndarray)
pset.addPrimitive(operator.sub, [np.ndarray, np.ndarray], np.ndarray)
# Only the first two of the twelve arguments are renamed in this excerpt.
pset.renameArguments(ARG0='close')
pset.renameArguments(ARG1='open')
# here is fitness function. My goal is maximum stock return's ICIR.
def evalSymbReg(individual):
    """Fitness function for one individual.

    The computation of ``icir`` is omitted from this excerpt; only the return
    statement is shown. The result is returned as a one-element tuple.
    """
    # omitting code
    return icir,

toolbox = base.Toolbox()
# Tree generation: half-and-half initialisation with depth between 1 and 3.
toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=3)
# An individual wraps one generated expression; a population is a list of individuals.
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("compile", gp.compile, pset=pset)

# Evolutionary operators: evaluation, tournament selection (size 10),
# one-point crossover, and uniform mutation using full trees of depth 0-2.
toolbox.register("evaluate", evalSymbReg)
toolbox.register("select", tools.selTournament, tournsize=10)
toolbox.register("mate", gp.cxOnePoint)
toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
toolbox.register("mutUniform", gp.mutUniform, expr=toolbox.expr_mut, pset=pset)
# Limit the height of trees produced by crossover and mutation to 10.
# These decorations must come after the corresponding register() calls.
toolbox.decorate("mate", gp.staticLimit(key=operator.attrgetter("height"), max_value=10))
toolbox.decorate("mutUniform", gp.staticLimit(key=operator.attrgetter("height"), max_value=10))

def main():
    """Run the GP evolution and return its results.

    Builds fitness/size statistics, a hall of fame of size 10 and an initial
    population of 5000 individuals, then delegates the 40-generation run to
    ``algorithms.my_eaSimple``.

    NOTE(review): ``mutNodeReplacementpb``, ``mutEphemeralpb``,
    ``mutShrinkpb``, ``info`` and ``top10`` are not defined anywhere in this
    excerpt, and ``my_eaSimple`` is presumably a custom addition to
    ``deap.algorithms`` -- confirm against the full script.

    Returns:
        The tuple ``(pop, log, hof, info, top10)``.
    """
    population_size = 5000
    generations = 40
    crossover_prob = 0.6
    uniform_mutation_prob = 0.4

    # Statistics over fitness values and tree sizes, reduced with the
    # NaN-aware NumPy functions.
    multi_stats = tools.MultiStatistics(
        fitness=tools.Statistics(operator.attrgetter("fitness.values")),
        size=tools.Statistics(len),
    )
    for label, reducer in (("avg", np.nanmean), ("min", np.nanmin), ("max", np.nanmax)):
        multi_stats.register(label, reducer)

    hof = tools.HallOfFame(10)
    pop = toolbox.population(n=population_size)

    pop, log = algorithms.my_eaSimple(
        pop, toolbox, crossover_prob, uniform_mutation_prob,
        mutNodeReplacementpb, mutEphemeralpb, mutShrinkpb,
        generations, stats=multi_stats, halloffame=hof, verbose=True,
    )

    return pop, log, hof, info, top10

# Input data file, loaded when the module is imported.
# NOTE(review): because this runs at module level, each worker process
# presumably re-reads the CSV when it re-imports this module under the
# "spawn" start method (the default on Windows) -- confirm this is intended.
df = pd.read_csv(r'C:\Users\xxyao\research\国债期货\data\data_summary.csv')
# Percentage change of 'close', shifted up by one row (pct_change() then shift(-1)).
df['pct-1'] = df['close'].pct_change().shift(-1)
# First seven characters of each 'date' value
# (presumably 'YYYY-MM' -- confirm the date format).
df['month'] = [x[0:7] for x in df['date']]

if __name__ == "__main__":
    # The guard keeps worker processes, which re-import this module under the
    # "spawn" start method, from creating pools of their own.
    # Using the pool as a context manager terminates the worker processes once
    # main() has returned, instead of leaving them running until interpreter
    # shutdown.
    with multiprocessing.Pool(processes=6) as pool:
        # Route the toolbox's map through the pool so that the calls the
        # algorithm makes through toolbox.map run in the workers.
        toolbox.register('map', pool.map)
        pop, log, hof, info, top10 = main()

运行代码时,我收到如下错误消息:

enter image description here

这条消息在窗口中快速地反复出现。我不知道哪里出了问题。按照 DEAP 文档的说明,我已经把 Pool() 的创建放在 `if __name__ == "__main__":` 的保护之下,但它仍然无法工作。有人可以帮帮我吗?

1 个答案:

答案 0 :(得分:0)

将下面的代码放到 main() 函数内部,这样就能正常工作。

# Create a pool of six worker processes and register its map on the toolbox.
pool = multiprocessing.Pool(processes=6)
toolbox.register('map', pool.map)