How to parallelize a GridSearch scan with Talos

Time: 2019-05-07 02:41:27

Tags: python talos

While Talos supports GPU parallelization, how can the Scan object be extended to support combined CPU + GPU parallelization?

1 answer:

Answer 0 (score: 0)

One approach is to split the scan experiment across multiple processes:

import multiprocessing as mp
import os
from itertools import product

import talos

# Verbosity forwarded to each Talos scan (hypothetical default: the original
# snippet used verbose_flag without defining it)
verbose_flag = 0


# Helper: split a list into n interleaved (round-robin) chunks
def chunkify(lst, n):
    return [lst[i::n] for i in range(n)]


# A superset of the Talos Scan configuration
playbook_configurations = {
    "input_lstm_dim": [5, 15, 30, 50],
    "dense_a_dim": [None, 5],
    "dense_b_dim": [None, 5],
    "dense_c_dim": [None, 5],
    "dropout_a_rate": [None, 0.7, 0.5, 0.3],
    "epochs": [100],
    "verbose": [verbose_flag],
    "batch_normalization": [None, 1],
}


# The actual scan run inside each process
def process_scan(playbook_scan_settings, output):
    scan = talos.Scan(
        ...
        params=playbook_scan_settings,
    )
    ...
    output.put(results)  # push this process's results (built above) onto the queue


if __name__ == "__main__":
    # Process-safe queue for scan results
    output = mp.Queue()

    # Worker count: CPUs this process may run on (sched_getaffinity is
    # Linux-only, so fall back to os.cpu_count() elsewhere)
    if hasattr(os, "sched_getaffinity"):
        cpu_count = len(os.sched_getaffinity(0))
    else:
        cpu_count = os.cpu_count() or 1

    # Cartesian product of the Talos configuration: one dict per combination
    playbook_configurations_cartesian_product = [
        dict(zip(playbook_configurations, v))
        for v in product(*playbook_configurations.values())
    ]

    # Configuration chunks, one per process
    playbook_configuration_groups = chunkify(
        playbook_configurations_cartesian_product, cpu_count)

    processes = []
    for playbook_configuration_group in playbook_configuration_groups:
        # Merge the group's combinations into per-key value lists, the format
        # talos.Scan expects for params. Scan expands these lists back into
        # their full Cartesian product, so when a chunk does not line up with
        # the grid's structure the merged grids overlap and some combinations
        # are scanned by more than one process.
        playbook_scan_settings = {}
        for g in playbook_configuration_group:
            for k, v in g.items():
                if k not in playbook_scan_settings:
                    playbook_scan_settings[k] = []
                if v not in playbook_scan_settings[k]:
                    playbook_scan_settings[k].append(v)
        if playbook_scan_settings:
            # One process scans the merged configuration of its group
            processes.append(mp.Process(
                target=process_scan, args=(playbook_scan_settings, output)))

    for p in processes:
        p.start()

    # Read one result per process BEFORE joining: a child that has put large
    # items on a Queue may not exit until the parent consumes them, so calling
    # join() first can deadlock
    results = [output.get() for _ in processes]

    for p in processes:
        p.join()
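To check how the chunk-and-merge step distributes work before launching any scans, the sketch below reimplements it in plain Python (no Talos needed) and counts how many runs the merged groups would schedule. `merge_group`, `grid_size`, and `split_on_key` are hypothetical helpers introduced only for this check; `split_on_key` is an alternative partitioning (dividing one parameter's values across groups), not the answer's method, and `verbose` is fixed to 0 as an assumption.

```python
from itertools import product


def chunkify(lst, n):
    # Same helper as in the answer: n interleaved (round-robin) chunks
    return [lst[i::n] for i in range(n)]


def merge_group(group):
    # Collapse a list of combination dicts into per-key value lists,
    # mirroring the merge loop in the answer's script
    merged = {}
    for combo in group:
        for k, v in combo.items():
            merged.setdefault(k, [])
            if v not in merged[k]:
                merged[k].append(v)
    return merged


def grid_size(params):
    # Number of combinations a full grid search over params would run
    size = 1
    for values in params.values():
        size *= len(values)
    return size


def split_on_key(params, key, n):
    # Alternative: spread the values of a single key over n groups and keep
    # every other key whole, so the groups are disjoint by construction
    groups = []
    for i in range(n):
        subset = params[key][i::n]
        if subset:  # skip empty groups when n > len(params[key])
            groups.append({**params, key: subset})
    return groups


# The answer's grid, with the verbose entry fixed to 0 (assumption)
grid = {
    "input_lstm_dim": [5, 15, 30, 50],
    "dense_a_dim": [None, 5],
    "dense_b_dim": [None, 5],
    "dense_c_dim": [None, 5],
    "dropout_a_rate": [None, 0.7, 0.5, 0.3],
    "epochs": [100],
    "verbose": [0],
    "batch_normalization": [None, 1],
}

combos = [dict(zip(grid, v)) for v in product(*grid.values())]

for n in (3, 8):
    merged = [merge_group(g) for g in chunkify(combos, n)]
    total = sum(grid_size(m) for m in merged)
    print(f"chunk+merge, n={n}: {total} runs for {len(combos)} combinations")

for n in (3, 8):
    total = sum(grid_size(g) for g in split_on_key(grid, "input_lstm_dim", n))
    print(f"split_on_key, n={n}: {total} runs for {len(combos)} combinations")
```

For this grid the interleaved chunks happen to line up with the grid when n = 8, giving exactly 256 runs, but with n = 3 every merged group re-expands to the full grid, so 768 runs are scheduled for 256 combinations. Splitting on a single key never overlaps, at the cost of capping the process count at that key's number of values (4 for `input_lstm_dim`).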

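You can just as easily push the Report object, the winning model, and the metrics from each process's scan onto the message queue for final selection.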
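As a minimal sketch of that final-selection idea, each worker below pushes a small summary dict onto a queue and the parent picks the best one. `score_params` is a hypothetical stand-in for the training that talos.Scan performs, and the summary keys (`best_score`, `best_params`, `n_evaluated`) are assumptions for illustration, not a Talos output format; in the real script you would put whatever you extract from the Scan or its Report object instead.

```python
import queue


def score_params(params):
    # Hypothetical stand-in for training and validating one configuration;
    # in the real script this work is done by talos.Scan
    return -abs(params["input_lstm_dim"] - 30) - (params["dropout_a_rate"] or 0)


def scan_worker(param_list, output):
    # Evaluate every configuration assigned to this worker and push a compact
    # summary (best params and score) onto the shared queue
    scored = [(score_params(p), p) for p in param_list]
    best_score, best_params = max(scored, key=lambda item: item[0])
    output.put({"best_score": best_score, "best_params": best_params,
                "n_evaluated": len(param_list)})


def select_best(summaries):
    # Final selection across workers: keep the summary with the highest score
    return max(summaries, key=lambda s: s["best_score"])


if __name__ == "__main__":
    # Demo with an in-process queue.Queue; the real script would pass its
    # multiprocessing.Queue and run scan_worker in separate processes
    q = queue.Queue()
    scan_worker([{"input_lstm_dim": 5, "dropout_a_rate": None},
                 {"input_lstm_dim": 30, "dropout_a_rate": 0.5}], q)
    scan_worker([{"input_lstm_dim": 15, "dropout_a_rate": 0.3},
                 {"input_lstm_dim": 50, "dropout_a_rate": None}], q)
    summaries = [q.get() for _ in range(2)]
    print(select_best(summaries))
```

The demo uses `queue.Queue` so it runs without spawning anything; with `mp.Queue` the calls are the same, since both expose `put` and `get`. Sending a compact, picklable summary instead of a whole model also keeps the data crossing the process boundary small.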