我想知道如何在Python中实现2D数组切片?
例如,
arr
是自定义类2D数组的实例。
如果我想在这个对象上启用2D切片语法,如下所示:
arr[:,1:3] #retrieve the 1 and 2 column values of every row
或
arr[,:3] #retrieve the 1 and 2 column values of every row
用法和语法就像numpy.array。但是,如何才能实现这种功能呢?
PS:
我的想法:
对于第一种情况,[:,1:3]
部分就像是两个切片的元组
然而,对于第二种情况[,1:3]
似乎很神秘。
答案 0 :(得分:7)
如果您想了解数组切片的规则,下面的图片可能会有所帮助:
答案 1 :(得分:2)
对于读取权限,您需要覆盖__getitem__
方法:
class ArrayLike(object):
def __init__(self):
pass
def __getitem__(self, arg):
(rows,cols) = arg # unpack, assumes that we always pass in 2-arguments
# TODO: parse/interpret the rows/cols parameters,
# for single indices, they will be integers, for slices, they'll be slice objects
# here's a dummy implementation as a placeholder
return numpy.eye(10)[rows, cols]
其中一个棘手的问题是__getitem__
总是只使用一个参数(除了self),
当你在方括号中放入多个逗号分隔的项时,你实际上提供了一个元组作为__getitem__
调用的参数;因此需要
解压缩这个元组(并且可选地验证元组的长度是否合适)
功能。
现在给出a = ArrayLike()
,你最终得到了
a[2,3]
表示rows=2
,cols=3
a[:3,2]
表示rows=slice(None, 3, None)
,cols=3
等等;你必须看documentation on slice objects来决定 如何使用切片信息从您的班级中提取所需的数据。
为了使它更像一个numpy数组,你也想要覆盖__setitem__
,以
允许分配元素/切片。
答案 2 :(得分:1)
obj[,:3]
无效python因此会引发SyntaxError
- 因此,无法在源文件中拥有该语法。 (当您尝试在numpy
数组上使用它时,它会失败)
答案 3 :(得分:0)
如果它是你自己的类并且你愿意传入一个字符串,那么这是一个hack。
How to override the [] operator?
class Array(object):
def __init__(self, m, n):
"""Create junk demo array."""
self.m = m
self.n = n
row = list(range(self.n))
self.array = map(lambda x:row, range(self.m))
def __getitem__(self, index_string):
"""Implement slicing/indexing."""
row_index, _, col_index = index_string.partition(",")
if row_index == '' or row_index==":":
row_start = 0
row_stop = self.m
elif ':' in row_index:
row_start, _, row_stop = row_index.partition(":")
try:
row_start = int(row_start)
row_stop = int(row_stop)
except ValueError:
print "Bad Data"
else:
try:
row_start = int(row_index)
row_stop = int(row_index) + 1
except ValueError:
print "Bad Data"
if col_index == '' or col_index == ":":
col_start = 0
col_stop = self.n
elif ':' in col_index:
col_start, _, col_stop = col_index.partition(":")
try:
col_start = int(col_start)
col_stop = int(col_stop)
except ValueError:
print "Bad Data"
else:
try:
col_start = int(col_index)
col_stop = int(col_index) + 1
except ValueError:
print "Bad Data"
return map(lambda x: self.array[x][col_start:col_stop],
range(row_start, row_stop))
def __str__(self):
return str(self.array)
def __repr__(self):
return str(self.array)
array = Array(4, 5)
print array
out: [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
array[",1:3"]
out: [[1, 2], [1, 2], [1, 2], [1, 2]]
array[":,1:3"]
out: [[1, 2], [1, 2], [1, 2], [1, 2]]