为什么NumPy将对象转换为浮点数?

时间:2017-08-28 11:07:00

标签: python class numpy intervals

我试图在NumPy数组中存储间隔(使用其特定算法)。如果我使用自己的Interval类,它可以工作,但是我的课很差,而且我的Python知识有限。

我知道pyInterval并且它非常完整。它涵盖了我的问题。唯一不起作用的是在NumPy数组中存储pyInterval对象。

class Interval(object):

    def __init__(self, lower, upper = None):
        if upper is None:
            self.upper = self.lower = lower
        elif lower <= upper:
            self.lower = lower
            self.upper = upper
        else:
            raise ValueError(f"Lower is bigger than upper! {lower},{upper}")

    def __repr__(self):
        return "Interval " + str((self.lower,self.upper))

    def __mul__(self,another):
        values = (self.lower * another.lower,
                        self.upper * another.upper,
                        self.lower * another.upper,
                        self.upper * another.lower)
        return Interval(min(values), max(values))

import numpy as np
from interval import interval

i = np.array([Interval(2,3), Interval(-3,6)], dtype=object) # My class
ix = np.array([interval([2,3]), interval([-3,6])], dtype=object) # pyInterval

这些是结果

In [30]: i
Out[30]: array([Interval (2, 3), Interval (-3, 6)], dtype=object)

In [31]: ix
Out[31]: 
array([[[2.0, 3.0]],

       [[-3.0, 6.0]]], dtype=object)

来自pyInterval的间隔已被转换为浮点列表的列表。如果他们保留区间算术,那就不成问题了......

In [33]: i[0] * i[1]
Out[33]: Interval (-9, 18)

In [34]: ix[0] * ix[1]
Out[34]: array([[-6.0, 18.0]], dtype=object)

Out[33]是希望的输出。使用pyInterval的输出不正确。显然使用原始pyInterval它就像一个魅力

In [35]: interval([2,3]) * interval([-3,6])
Out[35]: interval([-9.0, 18.0])

Here是pyInterval源代码。我不明白为什么使用这个对象NumPy并不像我期望的那样工作。

2 个答案:

答案 0 :(得分:2)

公平地说,numpy.ndarray构造函数很难推断应该将哪种数据放入其中。它接收类似于元组列表的对象并使用它。

但是,您可以通过不猜测数据的形状来帮助您的构造函数:

a = interval([2,3])
b = interval([-3,6])
ll = [a,b]
ix = np.empty((len(ll),), dtype=object)
ix[:] = [*ll]
ix[0]*ix[1] #interval([-9.0, 18.0])

答案 1 :(得分:1)

NumPy将每个区间视为两个数字的数组,并且它执行元素乘法,这是您不想要的。试试这个:

interval.__mul__(ix[0], ix[1])

这是您要调用的函数的直接调用。它应该给你你需要的答案,即使它不是很漂亮。要将它变成适用于整个数组的东西,你可以这样做:

itvmul = np.vectorize(interval.__mul__)

这将允许您对区间数组进行元素乘法运算:https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.vectorize.html