Bilinear interpolation of angles

Date: 2017-10-31 11:21:18

Tags: python-3.x bilinear-interpolation

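I have a 2D array of directional data. I need to interpolate it onto a higher-resolution grid, but off-the-shelf functions such as scipy's interp2d do not account for the discontinuity between 0 and 360 degrees.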
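A minimal illustration of that discontinuity: a plain average of 350 and 10 gives 180, the opposite direction, while moving half of the shortest arc gives 0. The helper names naive_mid and angular_mid are hypothetical; shortest_angle uses the same formula as the function in the question's code.

```python
def naive_mid(a, b):
    # Plain average: ignores that 0 and 360 degrees are the same direction.
    return (a + b) / 2


def shortest_angle(beg, end, amount):
    # Same formula as in the question: signed shortest arc from beg to end, scaled.
    return ((((end - beg) % 360) + 540) % 360 - 180) * amount


def angular_mid(a, b):
    # Move half of the shortest arc from a towards b, then wrap into [0, 360).
    return (a + shortest_angle(a, b, 0.5)) % 360


print(naive_mid(350, 10))    # 180.0
print(angular_mid(350, 10))  # 0.0
```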

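I have code that works for a single grid cell of 4 points (thanks to How to perform bilinear interpolation in Python and Rotation Interpolation), but I would like it to accept a large dataset at once, just like the interp2d function does. How can I incorporate this into the code below in a way that does not loop over all the data?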

Thanks!

def shortest_angle(beg, end, amount):
    # Signed shortest arc from beg to end (degrees, in [-180, 180)), scaled by amount
    delta = ((((end - beg) % 360) + 540) % 360) - 180
    return delta * amount

def bilinear_interpolation_rotation(x, y, points):
    '''Interpolate (x,y) from values associated with four points.

    The four points are a list of four triplets:  (x, y, value).
    The four points can be in any order.  They should form a rectangle.
    '''

    points = sorted(points)               # order points by x, then by y
    (x1, y1, q11), (_x1, y2, q12), (x2, _y1, q21), (_x2, _y2, q22) = points

    if x1 != _x1 or x2 != _x2 or y1 != _y1 or y2 != _y2:
        raise ValueError('points do not form a rectangle')
    if not x1 <= x <= x2 or not y1 <= y <= y2:
        raise ValueError('(x, y) not within the rectangle')
    # interpolate over the x value at each y point, along the shortest arc
    fxy1 = q11 + shortest_angle(q11, q21, (x - x1) / (x2 - x1))
    fxy2 = q12 + shortest_angle(q12, q22, (x - x1) / (x2 - x1))
    # interpolate over the y values
    fxy = fxy1 + shortest_angle(fxy1, fxy2, (y - y1) / (y2 - y1))

    # wrap the result back into [0, 360)
    return fxy % 360
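One way to remove the loop, sketched here under assumptions not stated in the question: NumPy is available, the data lie on a regular rectilinear grid (increasing 1-D arrays x and y, with angles[j, i] being the angle in degrees at (x[i], y[j])), and the query points are given as arrays xi, yi of the same shape. interp_angles_grid is a hypothetical name. It applies the question's shortest_angle formula element-wise, so every query point is handled by array operations rather than a Python loop.

```python
import numpy as np


def shortest_angle(beg, end, amount):
    # Signed shortest arc from beg to end (degrees, in [-180, 180)), scaled by amount.
    # Pure arithmetic, so it works element-wise on NumPy arrays and on scalars.
    return ((((end - beg) % 360) + 540) % 360 - 180) * amount


def interp_angles_grid(x, y, angles, xi, yi):
    """Bilinearly interpolate gridded angles (degrees) at the points (xi, yi).

    Assumes x (length nx) and y (length ny) are strictly increasing and
    angles has shape (ny, nx) with angles[j, i] located at (x[i], y[j]).
    Returns values in [0, 360) with the same shape as xi and yi.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    a = np.asarray(angles, dtype=float)
    xi = np.asarray(xi, dtype=float)
    yi = np.asarray(yi, dtype=float)

    # Lower-left corner index of the cell containing each query point,
    # clipped so that i + 1 and j + 1 are always valid indices.
    i = np.clip(np.searchsorted(x, xi, side="right") - 1, 0, len(x) - 2)
    j = np.clip(np.searchsorted(y, yi, side="right") - 1, 0, len(y) - 2)

    # Fractional position of each query point inside its cell.
    tx = (xi - x[i]) / (x[i + 1] - x[i])
    ty = (yi - y[j]) / (y[j + 1] - y[j])

    # Corner values gathered for all query points at once.
    q11 = a[j, i]
    q21 = a[j, i + 1]
    q12 = a[j + 1, i]
    q22 = a[j + 1, i + 1]

    # Same order as the question: along x on both rows, then along y,
    # always following the shortest arc.
    f1 = q11 + shortest_angle(q11, q21, tx)
    f2 = q12 + shortest_angle(q12, q22, tx)
    return (f1 + shortest_angle(f1, f2, ty)) % 360


print(interp_angles_grid([1, 3], [2, 4], [[5, 3], [6, 9]], [2.0, 1.0], [3.0, 2.0]))
```

To resample onto a finer grid, xi and yi can come from np.meshgrid. Points outside the grid are not rejected in this sketch: the clipped cell index makes them extrapolate from the nearest edge cell.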

1 answer:

Answer 0 (score: 1)

In this example I will reuse a couple of simplified versions of my own classes, Point and Point3D:

Point

class Point:
    # Constructors
    def __init__(self, x, y):
        self.x = x
        self.y = y

    # Properties
    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = float(value)

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, value):
        self._y = float(value)

    # Printing magic methods
    def __repr__(self):
        return "({p.x},{p.y})".format(p=self)

    # Comparison magic methods
    def __is_compatible(self, other):
        return hasattr(other, 'x') and hasattr(other, 'y')

    def __eq__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x == other.x) and (self.y == other.y)

    def __ne__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x != other.x) or (self.y != other.y)

    def __lt__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y) < (other.x, other.y)

    def __le__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y) <= (other.x, other.y)

    def __gt__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y) > (other.x, other.y)

    def __ge__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y) >= (other.x, other.y) 

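It represents a 2D point. It has a simple constructor; x and y properties that guarantee the coordinates are always stored as floats; a magic method for the string representation as (x,y); and comparison methods that make points sortable (by x, then by y). My original class has further features, such as magic methods for addition and subtraction (vector behaviour), but they are not needed for this example.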
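For comparison, a compact sketch of the same x-then-y ordering using the standard-library dataclasses module (Point2 is a hypothetical name, not the class above): order=True generates the comparison methods by comparing fields as tuples in declaration order, and __post_init__ mimics the float-converting setters. Unlike the class above, the generated comparisons only accept instances of the same class rather than any object with x and y attributes.

```python
from dataclasses import dataclass


@dataclass(order=True)
class Point2:
    """Hypothetical compact stand-in for Point: ordered by x, then by y."""
    x: float
    y: float

    def __post_init__(self):
        # Mirror the property setters of Point: always store floats.
        self.x = float(self.x)
        self.y = float(self.y)


pts = sorted([Point2(3, 2), Point2(1, 4), Point2(1, 2)])
print(pts)
```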

Point3D

class Point3D(Point):
    # Constructors
    def __init__(self, x, y, z):
        super().__init__(x, y)
        self.z = z

    @classmethod
    def from2D(cls, p, z):
        return cls(p.x, p.y, z)

    # Properties
    @property
    def z(self):
        return self._z

    @z.setter
    def z(self, value):
        self._z = (value + 180.0) % 360 - 180

    # Printing magic methods
    def __repr__(self):
        return "({p.x},{p.y},{p.z})".format(p=self)

    # Comparison magic methods
    def __is_compatible(self, other):
        return hasattr(other, 'x') and hasattr(other, 'y') and hasattr(other, 'z')

    def __eq__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x == other.x) and (self.y == other.y) and (self.z == other.z)

    def __ne__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x != other.x) or (self.y != other.y) or (self.z != other.z)

    def __lt__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y, self.z) < (other.x, other.y, other.z)

    def __le__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y, self.z) <= (other.x, other.y, other.z)

    def __gt__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y, self.z) > (other.x, other.y, other.z)

    def __ge__(self, other):
        if not self.__is_compatible(other):
            return NotImplemented
        return (self.x, self.y, self.z) >= (other.x, other.y, other.z)

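Same as Point, but for 3D points. It also includes an additional constructor, the from2D class method, which takes a Point and its z value as arguments. Its z setter wraps the stored value into the range [-180, 180).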
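The z setter's expression can be checked on its own; wrap_degrees below is a hypothetical standalone copy of it, showing that every input lands in [-180, 180).

```python
def wrap_degrees(value):
    # Same expression as the Point3D.z setter: maps any angle into [-180, 180).
    return (value + 180.0) % 360 - 180


for angle in (190, 360, -190, 180, 725):
    print(angle, "->", wrap_degrees(angle))
```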

Linear interpolation

def linear_interpolation(x, *points, extrapolate=False):
    # Check there are a minimum of two points
    if len(points) < 2:
        raise ValueError("Not enough points given for interpolation.")
    # Sort the points
    points = sorted(points)
    # Check that x is within the valid interpolation interval
    if not extrapolate and (x < points[0].x or x > points[-1].x):
        raise ValueError("{} is not in the interpolation interval.".format(x))
    # Determine which are the two surrounding interpolation points
    if x < points[0].x:
        i = 0
    elif x > points[-1].x:
        i = len(points)-2
    else:
        i = 0
        while points[i+1].x < x:
            i += 1
    p1, p2 = points[i:i+2]
    # Interpolate
    return Point(x, p1.y + (p2.y-p1.y) * (x-p1.x) / (p2.x-p1.x))

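It takes as its first positional argument the x for which we want to compute the y value, followed by any number of Point instances to interpolate between. The keyword argument extrapolate enables extrapolation. It returns a Point instance holding the requested x and the computed y value.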
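The search for the surrounding pair of points can also be done with the standard bisect module instead of a while loop. A sketch, assuming the samples are already sorted and passed as two plain sequences xs and ys (linear_interpolation_xy is a hypothetical name); the clamping reproduces the original's choice of the first or last segment when extrapolating.

```python
from bisect import bisect_left


def linear_interpolation_xy(x, xs, ys, extrapolate=False):
    """Hypothetical helper: interpolate y at x from sorted samples xs, ys."""
    if len(xs) < 2 or len(xs) != len(ys):
        raise ValueError("Need at least two (x, y) samples of equal length.")
    if not extrapolate and not (xs[0] <= x <= xs[-1]):
        raise ValueError("{} is not in the interpolation interval.".format(x))
    # First index whose x is >= the query, minus one, clamped so that
    # (i, i + 1) is always a valid segment (this also covers extrapolation).
    i = min(max(bisect_left(xs, x) - 1, 0), len(xs) - 2)
    x1, x2 = xs[i], xs[i + 1]
    y1, y2 = ys[i], ys[i + 1]
    return y1 + (y2 - y1) * (x - x1) / (x2 - x1)


print(linear_interpolation_xy(3, [2, 4], [5, 6]))  # 5.5
```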

Bilinear interpolation

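I provide two alternatives, and both share the traits of the previous interpolation function: they take the Point whose z value we want to compute, accept a keyword argument (extrapolate) that enables extrapolation, and return a Point3D instance containing the requested and computed data. The difference between the two methods is how the values used for interpolation are provided: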
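Both methods perform the same two-step computation once the four surrounding corners are known: linear interpolation along y on the two edges x = x1 and x = x2 (p12 and p34 in the code), then along x between those two results. A sketch of just that step on plain numbers; bilinear_rect is a hypothetical helper, and the call reuses the values from the answer's examples.

```python
def bilinear_rect(x, y, x1, x2, y1, y2, q11, q12, q21, q22):
    """Hypothetical helper: plain (non-angular) bilinear interpolation in one cell.

    qAB is the value at (xA, yB), e.g. q12 is the value at (x1, y2).
    """
    ty = (y - y1) / (y2 - y1)
    # First along y on the two vertical edges x = x1 and x = x2 ...
    z_x1 = q11 + (q12 - q11) * ty
    z_x2 = q21 + (q22 - q21) * ty
    # ... then along x between those two intermediate results.
    tx = (x - x1) / (x2 - x1)
    return z_x1 + (z_x2 - z_x1) * tx


print(bilinear_rect(2, 3, 1, 3, 2, 4, 5, 6, 3, 9))  # 5.75
```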

Method 1

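The first method takes a two-level nested dict: the first-level keys are the x values, the second-level keys are the y values, and the second-level values are the z values.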
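If the data start out as a 2-D list or array, the nested dict can be built with a comprehension. A sketch with a hypothetical helper grid_to_nested, assuming values[i][j] holds the z value at (xs[i], ys[j]); the call produces the same dict as the Method 1 example.

```python
def grid_to_nested(xs, ys, values):
    """Hypothetical helper: build the {x: {y: z}} dict used by Method 1.

    values[i][j] is assumed to be the z value at (xs[i], ys[j]).
    """
    return {x: {y: values[i][j] for j, y in enumerate(ys)} for i, x in enumerate(xs)}


points = grid_to_nested([1, 3], [2, 4], [[5, 6], [3, 9]])
print(points)  # {1: {2: 5, 4: 6}, 3: {2: 3, 4: 9}}
```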

def bilinear_interpolation(p, points, extrapolate=False):
    x_values = sorted(points.keys())
    # Check there are a minimum of two x values
    if len(x_values) < 2:
        raise ValueError("Not enough points given for interpolation.")
    y_values = set()
    for value in points.values():
        y_values.update(value.keys())
    y_values = sorted(y_values)
    # Check there are a minimum of two y values
    if len(y_values) < 2:
        raise ValueError("Not enough points given for interpolation.")
    # Check that p is in the valid interval
    if not extrapolate and (p.x < x_values[0] or p.x > x_values[-1] or p.y < y_values[0] or p.y > y_values[-1]):
        raise ValueError("{} is not in the interpolation interval.".format(p))
    # Determine which are the four surrounding interpolation points
    if p.x < x_values[0]:
        i = 0
    elif p.x > x_values[-1]:
        i = len(x_values) - 2
    else:
        i = 0
        while x_values[i+1] < p.x:
            i += 1
    if p.y < y_values[0]:
        j = 0
    elif p.y > y_values[-1]:
        j = len(y_values) - 2
    else:
        j = 0
        while y_values[j+1] < p.y:
            j += 1
    surroundings = [
                    Point(x_values[i  ], y_values[j  ]),
                    Point(x_values[i  ], y_values[j+1]),
                    Point(x_values[i+1], y_values[j  ]),
                    Point(x_values[i+1], y_values[j+1]),
                   ]
    for i, surrounding in enumerate(surroundings):
        try:
            surroundings[i] = Point3D.from2D(surrounding, points[surrounding.x][surrounding.y])
        except KeyError:
            raise ValueError("{} is missing in the interpolation grid.".format(surrounding))
    p1, p2, p3, p4 = surroundings
    # Interpolate
    p12 = Point3D(p1.x, p.y, linear_interpolation(p.y, Point(p1.y,p1.z), Point(p2.y,p2.z), extrapolate=True).y)
    p34 = Point3D(p3.x, p.y, linear_interpolation(p.y, Point(p3.y,p3.z), Point(p4.y,p4.z), extrapolate=True).y)
    return Point3D(p.x, p12.y, linear_interpolation(p.x, Point(p12.x,p12.z), Point(p34.x,p34.z), extrapolate=True).y)


print(bilinear_interpolation(Point(2,3), {1: {2: 5, 4: 6}, 3: {2: 3, 4: 9}}))

Method 2

The second method takes any number of Point3D instances.
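Converting between the two input formats is straightforward. A sketch using plain (x, y, z) tuples in place of Point3D instances (nested_to_triples and index_by_xy are hypothetical helpers). Method 2 finds each corner by scanning every point; indexing the triples by (x, y) once makes each corner lookup a single dict access, which may matter for large inputs.

```python
def nested_to_triples(points):
    """Hypothetical helper: flatten a {x: {y: z}} dict into (x, y, z) tuples,
    the shape of data Method 2 expects (one Point3D per tuple)."""
    return [(x, y, z) for x, row in points.items() for y, z in row.items()]


def index_by_xy(triples):
    # A dict keyed by (x, y) gives O(1) lookup of each cell corner instead of
    # scanning every point for each of the four surrounding corners.
    return {(float(x), float(y)): z for x, y, z in triples}


triples = nested_to_triples({1: {2: 5, 4: 6}, 3: {2: 3, 4: 9}})
print(triples)                           # [(1, 2, 5), (1, 4, 6), (3, 2, 3), (3, 4, 9)]
print(index_by_xy(triples)[(3.0, 4.0)])  # 9
```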

def bilinear_interpolation(p, *points, extrapolate=False):
    # Check there are a minimum of four points
    if len(points) < 4:
        raise ValueError("Not enough points given for interpolation.")
    # Sort the points into a grid
    x_values = set()
    y_values = set()
    for point in sorted(points):
        x_values.add(point.x)
        y_values.add(point.y)
    x_values = sorted(x_values)
    y_values = sorted(y_values)
    # Check that p is in the valid interval
    if not extrapolate and (p.x < x_values[0] or p.x > x_values[-1] or p.y < y_values[0] or p.y > y_values[-1]):
        raise ValueError("{} is not in the interpolation interval.".format(p))
    # Determine which are the four surrounding interpolation points
    if p.x < x_values[0]:
        i = 0
    elif p.x > x_values[-1]:
        i = len(x_values) - 2
    else:
        i = 0
        while x_values[i+1] < p.x:
            i += 1
    if p.y < y_values[0]:
        j = 0
    elif p.y > y_values[-1]:
        j = len(y_values) - 2
    else:
        j = 0
        while y_values[j+1] < p.y:
            j += 1
    surroundings = [
                    Point(x_values[i  ], y_values[j  ]),
                    Point(x_values[i  ], y_values[j+1]),
                    Point(x_values[i+1], y_values[j  ]),
                    Point(x_values[i+1], y_values[j+1]),
                   ]
    for point in points:
        for i, surrounding in enumerate(surroundings):
            if point.x == surrounding.x and point.y == surrounding.y:
                surroundings[i] = point
    for surrounding in surroundings:
        if not isinstance(surrounding, Point3D):
            raise ValueError("{} is missing in the interpolation grid.".format(surrounding))
    p1, p2, p3, p4 = surroundings
    # Interpolate
    p12 = Point3D(p1.x, p.y, linear_interpolation(p.y, Point(p1.y,p1.z), Point(p2.y,p2.z), extrapolate=True).y)
    p34 = Point3D(p3.x, p.y, linear_interpolation(p.y, Point(p3.y,p3.z), Point(p4.y,p4.z), extrapolate=True).y)
    return Point3D(p.x, p12.y, linear_interpolation(p.x, Point(p12.x,p12.z), Point(p34.x,p34.z), extrapolate=True).y)


print(bilinear_interpolation(Point(2,3), Point3D(3,2,3), Point3D(1,4,6), Point3D(3,4,9), Point3D(1,2,5)))

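As you can see, both methods use the previously defined linear_interpolation function, and both always pass extrapolate=True, because they have already raised an exception if extrapolate is False and the requested point lies outside the provided interval.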
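Note that linear_interpolation treats z as a plain number, and the Point3D setter only wraps the stored values: interpolating halfway between 170 and -170 gives 0, while the shortest arc between them passes through 180. A minimal sketch of an angle-aware replacement for the linear step, assuming the shortest-arc behaviour of the question's shortest_angle is wanted (angular_lerp and angular_bilinear are hypothetical names; they work on plain numbers rather than Point objects):

```python
def shortest_angle(beg, end, amount):
    # Formula from the question: signed shortest arc from beg to end, scaled by amount.
    return ((((end - beg) % 360) + 540) % 360 - 180) * amount


def wrap_degrees(value):
    # Same wrapping as the Point3D.z setter: result lies in [-180, 180).
    return (value + 180.0) % 360 - 180


def angular_lerp(x, x1, z1, x2, z2):
    """Hypothetical angle-aware counterpart of one linear_interpolation step."""
    t = (x - x1) / (x2 - x1)
    return wrap_degrees(z1 + shortest_angle(z1, z2, t))


def angular_bilinear(x, y, x1, x2, y1, y2, q11, q12, q21, q22):
    # Same order as the answer: along y at x1 and at x2, then along x.
    # qAB is the angle at (xA, yB).
    z_x1 = angular_lerp(y, y1, q11, y2, q12)
    z_x2 = angular_lerp(y, y1, q21, y2, q22)
    return angular_lerp(x, x1, z_x1, x2, z_x2)


print(angular_lerp(0.5, 0, 170, 1, -170))                          # -180.0
print(angular_bilinear(0.5, 0.5, 0, 1, 0, 1, 350, 350, 10, 10))  # 0.0
```

Here -180.0 is the same direction as 180, and the second call blends 350 and 10 across the 0/360 seam instead of through 180.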