如何使matplotlib处理自定义类“单位”

时间:2019-05-05 18:42:03

标签: python numpy matplotlib multidimensional-array

我正试图使我的课程与matplotlib's units兼容,并面临意外的行为。

这是我的自定义类的简化版本,它不是numpy的ndarray的子​​类:

import numpy as np
import matplotlib
import matplotlib.units as units

class Toto:
    def __init__(self, value_like, unit_like):
        self.value_like = value_like # typically a scalar or array
        self.unit_like = unit_like # a string describing the unit

    def __array__(self, *args, **kwargs):
        return np.array(self.value_like, *args, **kwargs)

# To test if plot as expected, without units handling
arr_x = Toto(np.arange(5), "meter")
arr_y = Toto(np.arange(5), "second")
plt.plot(arr_x, arr_y)

注意,我添加了一个__array__方法以使其与matplotlib成为“可绘制的”(如果不是,当numpy尝试使用TypeError: float() argument must be a string or a number, not 'Toto'转换Toto时,我会遇到一个array(toto_instance, float)异常) 。我怀疑我的问题实际上是由这种方法引起的,但是我不知道为什么/如何做。无论如何,继续解决实际问题:

现在,我跟随example in the doc为我的Toto类创建了一个转换界面:

class TotoConverter(units.ConversionInterface):

    @staticmethod
    def convert(value, unit, axis):
        'Convert a toto object value to a scalar or array'
        old_toto_unit = axis.get_unit()
        # stupid computation to determine new_unit (simpler for a MWE)
        new_unit = old_toto_unit
        new_toto = Toto(value, new_unit)
        return new_toto.value_like

    @staticmethod
    def axisinfo(unit, axis):
        return units.AxisInfo(label=str(unit))

    @staticmethod
    def default_units(x, axis):
        'Return the default unit for x or None'
        return getattr(x, 'unit_like', None)

最后,我将类的转换接口添加到matplotlib的转换接口注册表中:

units.registry[Toto] = TotoConverter()

然后是问题: 此时,在绘制Toto实例时应该在标签上显示单位,但是得到的结果与定义和注册单位转换接口之前的结果相同。这是为什么 ?

我怀疑因为我的Toto实例已转换为ndarray,所以从未调用过转换对象,但我不确定

欢呼

1 个答案:

答案 0 :(得分:0)

我怀疑您的意思是这样的,其中您有一个Toto的列表/数组,而不是Toto的值。

import matplotlib.pyplot as plt
import matplotlib.units as units

class Toto:
    def __init__(self, value_like, unit_like):
        self.value_like = value_like # typically a scalar or array
        self.unit_like = unit_like # a string describing the unit

class TotoConverter(units.ConversionInterface):

    @staticmethod
    def convert(value, unit, axis):
        if isinstance(value, Toto):
            return value.value_like
        else:
            return [toto.value_like for toto in value]

    @staticmethod
    def axisinfo(unit, axis):
        return units.AxisInfo(label=str(unit))

    @staticmethod
    def default_units(x, axis):
        'Return the default unit for x or None'
        if isinstance(x, Toto):
            return getattr(x, 'unit_like', None)
        else:
            return getattr(x[0], 'unit_like', None)

然后注册并使用它,

units.registry[Toto] = TotoConverter()


arr_x = [Toto(i, "meter") for i in range(5)]
arr_y = [Toto(i, "second") for i in range(5)]

plt.plot(arr_x, arr_y)            #use lists of Totos
plt.axhline(Toto(2, "second"))    # use Toto scalars
plt.xlim(Toto(-1, "meter"), None) # use Toto scalars

plt.show()