I want NoBroadcastArray to be a subclass of np.ndarray. If x is an instance of NoBroadcastArray and arr is an instance of np.ndarray, then I want
x[slice] = arr
to succeed only when arr.size matches the size of the slice.
x[1] = 1 # should succeed
x[1:3] = 1 # should fail - scalar doesn't have size 2
x[1:3] = [1,2] # should succeed
x[1:3] = np.array([[1,2]]) # should succeed - shapes don't match but sizes do.
x[1:3, 3:5] = np.array([1,2]) # should fail - 1x2 array doesn't have same size as 2x2 array
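In other words, the assignment should succeed only if the RHS does not have to change size to fit the LHS slice. I don't mind if it changes shape, for example from an array of shape 1x2 to an array of shape 2x1x1.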
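For context, here is a minimal sketch of the default behavior the question wants to prevent: a plain np.ndarray silently broadcasts the right-hand side whenever the shapes are compatible. The array x and its shape (4, 6) are assumptions made for illustration and are not part of the original question.

```python
import numpy as np

# Illustrative array; the shape is an assumption made for this sketch.
x = np.zeros((4, 6))

# A scalar is broadcast to fill the two-element slice.
x[1:3, 0] = 1

# A length-2 array is broadcast across both rows of the 2x2 block.
x[1:3, 3:5] = np.array([1, 2])

# Shapes that cannot be broadcast at all are already rejected by NumPy.
try:
    x[1:3, 3:5] = np.array([1, 2, 3])
    rejected = False
except ValueError:
    rejected = True
```

The first two assignments are the ones a non-broadcasting subclass would have to turn into errors; the third already fails in plain NumPy.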
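How can I achieve this? The path I am currently trying is to override __setitem__ in NoBroadcastArray so that it compares the size of the slice with the size of the item being set. This is proving tricky, so I wonder whether anyone has a better idea, possibly using __array_wrap__ or __array_finalize__.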
Answer 0 (score: 1)
Here is an implementation I came up with:
import numpy as np


class NoBroadcastArray(np.ndarray):

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __setitem__(self, args, value):
        value = np.asarray(value, dtype=self.dtype)
        expected_size = self._compute_expected_size(args)
        if expected_size != value.size:
            raise ValueError("assigned value size {} does not match expected size {} "
                             "in non-broadcasting assignment".format(value.size, expected_size))
        return super(NoBroadcastArray, self).__setitem__(args, value)

    def _compute_expected_size(self, args):
        if not isinstance(args, tuple):
            args = (args,)
        # Iterate through indexing arguments
        arr_dim = 0
        i_arg = 0
        size = 1
        adv_idx_shapes = []
        for i_arg, arg in enumerate(args):
            if isinstance(arg, slice):
                size *= self._compute_slice_size(arg, arr_dim)
                arr_dim += 1
            elif arg is Ellipsis:
                ellipsis_dim = arr_dim
                break
            elif arg is np.newaxis:
                pass
            else:
                adv_idx_shapes.append(np.shape(arg))
                arr_dim += 1
        else:
            # No ellipsis: the implicit trailing dimensions start right after
            # the last indexed axis (np.newaxis entries do not consume an axis)
            ellipsis_dim = arr_dim
        # Go backwards from end after ellipsis if necessary
        arr_dim = -1
        for arg in args[:i_arg:-1]:
            if isinstance(arg, slice):
                size *= self._compute_slice_size(arg, arr_dim)
                arr_dim -= 1
            elif arg is Ellipsis:
                raise IndexError("an index can only have a single ellipsis ('...')")
            elif arg is np.newaxis:
                pass
            else:
                adv_idx_shapes.append(np.shape(arg))
                arr_dim -= 1
        # Include dimensions under ellipsis
        ellipsis_end_dim = arr_dim + self.ndim + 1
        if ellipsis_dim > ellipsis_end_dim:
            raise IndexError("too many indices for array")
        for i_dim in range(ellipsis_dim, ellipsis_end_dim):
            size *= self.shape[i_dim]
        size *= NoBroadcastArray._advanced_index_size(adv_idx_shapes)
        return size

    def _compute_slice_size(self, slc, axis):
        if axis >= self.ndim or axis < -self.ndim:
            raise IndexError("too many indices for array")
        # slice.indices clamps start and stop to the axis length and raises
        # ValueError for a zero step, so the resulting range has exactly as
        # many elements as the slice selects (including out-of-range bounds)
        return len(range(*slc.indices(self.shape[axis])))

    @staticmethod
    def _advanced_index_size(shapes):
        # Size of the result of broadcasting the advanced index shapes together
        size = 1
        if not shapes:
            return size
        dims = max(len(s) for s in shapes)
        for dim_sizes in zip(*(s[::-1] + (1,) * (dims - len(s)) for s in shapes)):
            d = 1
            for dim_size in dim_sizes:
                if dim_size != 1:
                    if d != 1 and dim_size != d:
                        raise IndexError("shape mismatch: indexing arrays could not be "
                                         "broadcast together with shapes " + " ".join(map(str, shapes)))
                    d = dim_size
            size *= d
        return size
You would use it like this:
import numpy as np
a = NoBroadcastArray(np.arange(24).reshape(4, 3, 2, 1))
a[:] = 1
# ValueError: assigned value size 1 does not match expected size 24 in non-broadcasting assignment
a[:, ..., [0, 1], :] = 1
# ValueError: assigned value size 1 does not match expected size 24 in non-broadcasting assignment
a[[[0, 1], [2, 3]], :, [1, 0]] = 1
# ValueError: assigned value size 1 does not match expected size 12 in non-broadcasting assignment
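This only checks that the size of the given value matches the index. It does not reshape the value, so that part still works as usual in NumPy (that is, additional outer dimensions can be added).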
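As a cross-check on the size computation, the same quantities can be obtained from the standard library and NumPy directly: slice.indices normalizes a slice against an axis length, and np.broadcast_shapes (available in NumPy 1.20 and later) applies the broadcasting rule to the shapes of the index arrays. The function names slice_length and advanced_index_size below are hypothetical, introduced only for this sketch.

```python
import math
import numpy as np


def slice_length(slc, axis_length):
    """Number of elements a slice selects on an axis of the given length."""
    return len(range(*slc.indices(axis_length)))


def advanced_index_size(shapes):
    """Number of elements produced by broadcasting the index-array shapes."""
    return math.prod(np.broadcast_shapes(*shapes))


a = np.arange(24).reshape(4, 3, 2, 1)

# Slices, including a negative step and an out-of-range start.
slices = [slice(None), slice(1, 3), slice(None, None, -1), slice(10, None)]
slice_checks = [slice_length(s, a.shape[0]) == a[s].shape[0] for s in slices]

# a[[[0, 1], [2, 3]], :, [1, 0]]: the (2, 2) and (2,) index arrays broadcast
# together, axis 1 is sliced in full, and axis 3 is kept implicitly.
adv = advanced_index_size([(2, 2), (2,)])
expected = adv * slice_length(slice(None), a.shape[1]) * a.shape[3]
actual = a[[[0, 1], [2, 3]], :, [1, 0]].size
print(all(slice_checks), expected, actual)
```

Comparing against the size of the actual indexing result, as in the last lines, is a simple way to test any hand-written size computation on new index patterns.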
Answer 1 (score: 0)
Here is a shorter solution:
import numpy as np


class FixedSizeSetitemArray(np.ndarray):

    def __setitem__(self, index, value):
        value = np.asarray(value)
        current = self[index]
        if value.shape == current.shape:
            # Shapes already agree: assign directly
            super().__setitem__(index, value)
        elif value.size == current.size:
            # Same number of elements: reshape to fit the indexed region
            super().__setitem__(index, value.reshape(current.shape))
        else:
            old, new, cls = current.size, value.size, self.__class__.__name__
            raise ValueError(f"{cls} will not broadcast in __setitem__ "
                             f"(expected size {old}, got size {new})")
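While this meets the exact requirements given, it also reshapes the array arbitrarily to fit the given region, which may not actually be desirable. For example, it will happily reshape an array of shape (2, 2, 2) to (8,) and vice versa. To get rid of this behavior, just take out the elif block.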
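If you only want extra dimensions to be dropped, you can use np.squeeze in the elif block instead:
        elif value.squeeze().shape == current.shape:
            super().__setitem__(index, value.squeeze())
Other variations on squeeze would allow more general dropping of extra dimensions, but if you run into that situation, fixing the index you are using is probably a better idea.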
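Put together, the squeeze variant might look like the sketch below. The class name SqueezingSetitemArray and the test array are hypothetical, chosen only for illustration; the array is created with view because the class defines no __new__.

```python
import numpy as np


class SqueezingSetitemArray(np.ndarray):
    """Hypothetical variant: drops length-1 axes from the value, never broadcasts."""

    def __setitem__(self, index, value):
        value = np.asarray(value)
        current = self[index]
        if value.shape == current.shape:
            super().__setitem__(index, value)
        elif value.squeeze().shape == current.shape:
            # Only length-1 axes were removed; the data itself is unchanged
            super().__setitem__(index, value.squeeze())
        else:
            cls = self.__class__.__name__
            raise ValueError(f"{cls} will not broadcast in __setitem__ "
                             f"(expected shape {current.shape}, got shape {value.shape})")


x = np.zeros((4, 6)).view(SqueezingSetitemArray)

x[1:3, 0] = np.array([[1, 2]])  # shape (1, 2) squeezes to (2,): accepted

try:
    x[1:3, 0] = 1               # a scalar would need broadcasting: rejected
    rejected = False
except ValueError:
    rejected = True
```

Note that this only squeezes the value, not the target: if the indexed region itself has length-1 axes, a value without them is rejected unless the shapes match exactly.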