如何解释Python类中的@expand_dims?

时间:2018-01-24 06:10:15

标签: python numpy

我是Python的新手,我正在尝试使用函数“one_hot_to_label_batch”,该函数可以在此website的第115行找到。

然而,在此函数的正上方有一个“@expand_dims”。这是我第一次遇到这个。我知道“expand_dims”在Numpy范围内,但我不知道为什么在此处将其定义为“@expand_dims”。任何澄清将不胜感激。

1 个答案:

答案 0 :(得分:2)

这是装饰者。请参阅file的顶部:

from .util import expand_dims

从这一行,我们可以告诉装饰器是在同一目录中的util.py文件中定义的。我们在util.py文件中搜索,您可以找到以下函数:

def expand_dims(func):
    def wrapper(self, batch):
        ndim = batch.ndim
        if ndim == 3:
            batch = np.expand_dims(batch, axis=0)
        batch = func(self, batch)
        if ndim == 3:
            batch = np.squeeze(batch, axis=0)
        return batch
    return wrapper

上述功能取自this source

如果您不熟悉,请先检查装饰员。