如何使用一个数组的索引定义另一个的__getitem __?

时间:2018-12-09 21:48:55

标签: python oop slice

我有一个自定义类Field的对象,该对象实际上包裹着numpy.ndarray对象。该对象由两个输入定义:一个值数组(values)和一个切片对象(segment),该对象定义将这些值放置在较大数组(grid)中的位置。

我希望能够使用grid的索引来访问values的项目。应该可以通过定义自定义Field.__getitem__方法来实现。

import numpy as np

class Field:
    def __init__(self, values, segment, grid):
        if (not isinstance(segment, slice)) \\
        or (not isinstance(values, np.ndarray)) :
            raise TypeError
        if segment.step not in [1, -1]:
            raise ValueError('Segment must be continuous')
        if len(grid[segment]) != len(values):
            raise ValueError('values length must match segment')

        self.values = values
        self.segment = segment 
        self.grid = grid

    def __getitem__(self, key):
        new_key = ...  # <--- Code goes here
        return self.values[new_key]

grid = np.array([0.5, 1.5, 2.5, 3.5, 4.5])

values = np.array([42., 43., 44.])
segment = slice(2, 5)

my_field = Field(values, segment, grid)
print(grid[segment])  # output: [2.5, 3.5, 4.5]
print(my_field[2])  # Desired output: 42.
print(my_field[3])  # Desired output: 43.
print(my_field[0])  # Desired output: IndexError

要点是segment定义了gridmy_field中定义的一组位置。 我解决这个问题的方法非常笨拙,并且基于定义布尔值index = np.zeros_like(grid, dtype=bool); index[segment] = True的数组,然后使用np.cumsum(index)的一些技巧...

如何以更简单的方式实现此行为?

1 个答案:

答案 0 :(得分:1)

您可以通过一个明确的步骤定义切片:

z

这是为了确保您segment = slice(2, 5, 1) 中的segment.step返回__init__。然后定义一个方法来检查您的输入1是否在适当的key中:

range

这给出了:

def __getitem__(self, key):
    start, stop = self.segment.start, self.segment.stop
    new_key = key - start
    if new_key not in range(stop - start):
        raise IndexError(f'Key must be in range({start}, {stop})')
    return self.values[new_key]