我有一个想要在numpy数组的列上测试谓词的函数,假设它们总计为10.该函数将采用1D或2D数组,其中1D数组被视为a单列。
对于2D情况,我可以这样做:
python
for col in two_dim_array.T:
assert sum(col) == 10
我知道1D案例我可以做到:
python
assert sum(one_dim_array) == 10
但是有没有办法让一个代码路径与数组的类型无关,即我不必打开len(my_array.shape)
并使用上面的任何一个代码片段,例如:
python
for col in one_or_two_dim_array.cols():
assert sum(col) == 10
对于1D情况,我们只会通过循环一次。
答案 0 :(得分:2)
在以下两种情况下,下面会产生一列列和:
column_totals = one_or_two_dim_array.sum(axis=0).flatten()
如果需要,您可以循环遍历column_totals
中的值,或者一次性断言所有比较:
assert np.all(column_totals == 10)
事实上,整个事情可以缩写为一行:
assert np.all(one_or_two_dim_array.sum(axis=0) == 10)
答案 1 :(得分:0)
你的意思是什么?
import numpy as np
def test(arr):
if np.ndim(arr) > 1:
arr = arr.T
for col in arr:
assert np.sum(col) == 10
arr1 = np.array([1,2,3])
arr2 = np.array([[1,2,3],[4,5,6]])
print(test(arr1))
print
print(test(arr2))