My program that uses Numba, as well as the standard Numba examples, does not show up in nvidia-smi. What could be the problem?

Time: 2018-07-12 14:20:35

Tags: python gpu python-multiprocessing jit numba

The function below gets roughly a 40x speedup, but the GPU usage is not reflected in nvidia-smi. I am running the program from Spyder. top shows 40 python processes, which comes from multiprocessing with 40 workers and a chunk size of 4 (a rough sketch of the driver is given after the function). ls -lrt shows that the output files are generated correctly. Spyder is running Python 2.7. The Numba version is 0.38.0 and Spyder is the latest version.
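As a quick sanity check (just a sketch, assuming the CUDA part of Numba imports cleanly in the same Spyder environment), this is the kind of probe I can run to see whether Numba detects the GPU at all; the function in question follows below.

from numba import cuda

# is_available() / detect() / list_devices() are standard Numba calls; they
# only report whether a CUDA driver and device are visible to Numba.
print("CUDA available:", cuda.is_available())
if cuda.is_available():
    cuda.detect()                      # prints the devices Numba can see
    print(list(cuda.list_devices()))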

import numpy as np
import numba as nb

# NOTE: c is a module-level 128x128x128 output array defined elsewhere (not shown)
@nb.njit(fastmath=True, parallel=True, error_model="numpy", nogil=True)
def heavywork(x, y, z, a, b):
    e = 0
    X = Y = Z = 128
    for x2 in range(x - 1, x + 5):
        for y2 in range(y - 1, y + 5):
            for z2 in range(z - 1, z + 5):
                if (-1 < x < X and
                    -1 < y < Y and
                    -1 < z < Z and
                    (x != x2 or y != y2 or z != z2) and
                    (0 <= x2 < X) and
                    (0 <= y2 < Y) and
                    (0 <= z2 < Z)):
                    q = a[x2][y2][z2]
                    di = np.sqrt((x - x2) ** 2 + (y - y2) ** 2 + (z - z2) ** 2) * 1.2
                    if di <= 6 and di >= 2:
                        e = 4
                    elif di > 6 and di < 8:
                        e = 38 * di - 224
                    elif di >= 8:
                        e = 80
                    else:
                        continue
                    value = q / (e * di)
                    c[x][y][z] = c[x][y][z] + value

    return c
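The multiprocessing driver is not shown above; roughly it is a pool along the lines of the sketch below (worker, coords and the placeholder arrays are hypothetical simplifications of my real code), which is what produces the 40 python processes in top.

import multiprocessing as mp
import numpy as np

a = np.random.rand(128, 128, 128)      # placeholder input arrays
b = np.random.rand(128, 128, 128)

def worker(coord):
    # each of the 40 worker processes calls the njit-compiled heavywork;
    # note that every process gets its own copy of module-level globals such as c
    x, y, z = coord
    return heavywork(x, y, z, a, b)

if __name__ == "__main__":
    coords = [(x, y, z) for x in range(128)
                        for y in range(128)
                        for z in range(128)]
    pool = mp.Pool(processes=40)                     # the "40 agents"
    results = pool.map(worker, coords, chunksize=4)  # the "block size" of 4
    pool.close()
    pool.join()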

Example -

This one is from the Numba website itself, used to test multithreading. I have also tried more of the sample Numba-accelerated examples, such as bubble sort, vector addition, etc., but with the same result (a minimal CUDA sketch for comparison is given after the example).

# -*- coding: utf-8 -*-
"""
Example of multithreading by releasing the GIL through ctypes.
"""
from __future__ import print_function, division, absolute_import

from timeit import repeat
import threading
from ctypes import pythonapi, c_void_p
from math import exp

import numpy as np
from numba import jit, void, double

nthreads = 32
size = int(1e6)  # np.random.rand expects an integer size

def timefunc(correct, s, func, *args, **kwargs):
    print(s.ljust(20), end=" ")
    # Make sure the function is compiled before we start the benchmark
    res = func(*args, **kwargs)
    if correct is not None:
        assert np.allclose(res, correct)
    # time it
    print('{:>5.0f} ms'.format(min(repeat(lambda: func(*args, **kwargs),
                                      number=5, repeat=2)) * 1000))
    return res

def make_singlethread(inner_func):
    def func(*args):
        length = len(args[0])
        result = np.empty(length, dtype=np.float64)
        inner_func(result, *args)
        return result
    return func

def make_multithread(inner_func, numthreads):
    def func_mt(*args):
        length = len(args[0])
        result = np.empty(length, dtype=np.float64)
        args = (result,) + args        
        chunklen = (length + 1) // numthreads
        chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args]
                  for i in range(numthreads)]

        # You should make sure inner_func is compiled at this point, because
        # the compilation must happen on the main thread. This is the case
        # in this example because we use jit().
        threads = [threading.Thread(target=inner_func, args=chunk)
                   for chunk in chunks[:-1]]
        for thread in threads:
            thread.start()

        # the main thread handles the last chunk
        inner_func(*chunks[-1])

        for thread in threads:
            thread.join()
        return result
    return func_mt

savethread = pythonapi.PyEval_SaveThread
savethread.argtypes = []
savethread.restype = c_void_p

restorethread = pythonapi.PyEval_RestoreThread
restorethread.argtypes = [c_void_p]
restorethread.restype = None

def inner_func(result, a, b):
    threadstate = savethread()
    for i in range(len(result)):
        result[i] = exp(2.1 * a[i] + 3.2 * b[i])
    restorethread(threadstate)

signature = void(double[:], double[:], double[:])
inner_func_nb = jit(signature, nopython=True)(inner_func)
func_nb = make_singlethread(inner_func_nb)
func_nb_mt = make_multithread(inner_func_nb, nthreads)

def func_np(a, b):
    return np.exp(2.1 * a + 3.2 * b)

a = np.random.rand(size)
b = np.random.rand(size)
c = np.random.rand(size)

correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
timefunc(correct, "numba (1 thread)", func_nb, a, b)
timefunc(correct, "numba (%d threads)" % nthreads, func_nb_mt, a, b)

0 Answers
