Python 2.7和Cython:未定义全局名称col2im_6d_cython

时间:2017-04-15 11:31:13

标签: python cython

我在Ubuntu 16.04上遇到python 2.7和Cython的问题。 我正在尝试从cs231n课程(卷积神经网络)运行代码。 但是唯一的函数 col2im_6d_cython 不起作用。错误是:

NameError: global name 'col2im_6d_cython' is not defined

函数 col2im_6d_cython im2col_cython.pyx中定义:

def col2im_6d_cython(np.ndarray[DTYPE_t, ndim=6] cols, int N, int C, int H, int W,
        int HH, int WW, int pad, int stride):
    cdef np.ndarray x = np.empty((N, C, H, W), dtype=cols.dtype)
    cdef int out_h = (H + 2 * pad - HH) / stride + 1
    cdef int out_w = (W + 2 * pad - WW) / stride + 1
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((N, C, H + 2 * pad, W + 2 * pad),
                                                         dtype=cols.dtype)

    col2im_6d_cython_inner(cols, x_padded, N, C, H, W, HH, WW, out_h, out_w, pad, stride)

    if pad > 0:
        return x_padded[:, :, pad:-pad, pad:-pad]
    return x_padded 

调用 col2im_6d_cython 的文件是 fast_layers.py

from cs231n.im2col_cython import col2im_cython, im2col_cython
from cs231n.im2col_cython import col2im_6d_cython

def conv_backward_strides(dout, cache):
        x, w, b, conv_param, x_cols = cache
    stride, pad = conv_param['stride'], conv_param['pad']

    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    _, _, out_h, out_w = dout.shape

    db = np.sum(dout, axis=(0, 2, 3))

    dout_reshaped = dout.transpose(1, 0, 2, 3).reshape(F, -1)
    dw = dout_reshaped.dot(x_cols.T).reshape(w.shape)

    dx_cols = w.reshape(F, -1).T.dot(dout_reshaped)
    dx_cols.shape = (C, HH, WW, N, out_h, out_w)
    dx = col2im_6d_cython(dx_cols, N, C, H, W, HH, WW, pad, stride)

    return dx, dw, db

col2im_cython im2col_cython 正常工作,但只有 col2im_6d_cython 不起作用。

在我看来,Cython安装存在问题。我已经通过运行来安装它: python setup.py build_ext --inplace

setup.py是:

from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
import numpy

extensions = [
Extension('im2col_cython', ['im2col_cython.pyx'],
        include_dirs = [numpy.get_include()]
),
]

setup(
    ext_modules = cythonize(extensions),
)

我在安装Cython时发出警告:

Warning: Extension name 'im2col_cython' does not match fully qualified name 'cs231n.im2col_cython' of 'im2col_cython.pyx'
running build_ext

为什么只有 col2im_6d_cython 不起作用?有没有办法解决它?

提前谢谢大家!

1 个答案:

答案 0 :(得分:0)

通过完全卸载Anaconda3并安装Anaconda2解决了这个问题。然后我创建了新环境并重新安装了所有需要的软件包。现在没有出现错误