我有一个类似的列表:
[[0,1,2], [1,2,3], [2,3,4], [3,4,5]]
我可以把它变成一个像:
这样的数组array([[0,1,2],
[1,2,3],
[2,3,4],
[3,4,5]])
所以我一共有4行,每行有3列。现在我想找到大于2的所有元素的索引,因此对于整个矩阵,索引应该是:
((1,2),(2,1),(2,2),(3,1),(3,2),(3,3))
然后对于每一行,我会随机选出一个col索引,表示一个大于2的值。现在我的代码就像:
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]]
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in np.nonzero(a>2):
if r == cur_row:
col_list.append(c)
else:
cur_row = r
shuffled_list = shuffle(col_list)
out[r-1] = shuffled_list[0]
col_list = []
col_list.append(c)
我希望得到一个看起来像:
array([-1, 2, 1, 2])
但是,现在当我运行我的代码时,它会显示
ValueError: too many values to unpack
任何人都知道我如何解决这个问题?或者我该如何实现我的目标?我只想尽快运行代码,所以任何其他好的想法也非常受欢迎。
答案 0 :(得分:3)
试试这个。
import numpy as np
arr = np.array([[0,1,2],
[1,2,3],
[2,3,4],
[3,4,5]])
indices = np.where(arr>2)
for r, c in zip(*indices):
print(r, c)
打印
1 2
2 1
2 2
3 0
3 1
3 2
所以,它应该工作。您也可以使用itertools.izip
,在这种情况下它甚至可以是更好的选择。
纯粹的numpy
解决方案(感谢@AshwiniChaudhary提出的建议):
for r, c in np.vstack(np.where(arr>2)).T:
...
虽然我不确定这会比使用izip或zip更快。
答案 1 :(得分:0)
您可以将数组与您的值进行比较,并使用where。
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
np.where(a>2)
(array([1,2,2,3,3,3],dtype = int64),array([2,1,2,0,1,2], D型= int64类型))
获取你的元组
list(zip(*np.where(a>2)))
[(1,2),(2,1),(2,2),(3,0),(3,1),(3,2)]
答案 2 :(得分:0)
我已经完成了,代码应该是:
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in zip(*(np.nonzero(a>2))):
if r == cur_row:
col_list.append(c)
else:
cur_row = r
shuffle(col_list)
if len(col_list) == 0:
out[r-1] = -1
else:
out[r-1] = col_list[0]
col_list = []
col_list.append(c)
shuffle(col_list)
if len(col_list) == 0:
out[len(out)-1] = -1
else:
out[len(out)-1] = col_list[0]
最后但在forloop之外的部分是为了确保最后一行将被处理。
它适用于我的情况。