Numpy按原始索引取消过滤数据

时间:2018-11-24 07:24:27

标签: python numpy scipy

我有多维数据X及其对应的标签y。我通过以下方式按类过滤X;

import numpy as np

np.random.seed(0)
X = np.random.rand(220,22,1125)
y = np.random.randint(4, size=(220))


# Class indexes
index_class0 = np.where(y==0)[0]
index_class1 = np.where(y==1)[0]
index_class2 = np.where(y==2)[0]
index_class3 = np.where(y==3)[0]

# Filtering X by classes
X0 = X[index_class0,:,:]
X1 = X[index_class1,:,:]
X2 = X[index_class2,:,:]
X3 = X[index_class3,:,:]

# Assume some operations are performed on X0-X3
# TODO: reconstruct X using X0-X3, having same class indexes.

现在给定X0,X1,X2和X3以及相应的类索引,在类顺序保持不变的情况下,如何重构X?

1 个答案:

答案 0 :(得分:0)

我们已经知道原始数组中不同类的索引。因此,只需创建一个空数组,然后将类的X放在正确的位置即可。

reconstructed_X = np.zeros(X.shape)

reconstructed_X[index_class0] = X0
reconstructed_X[index_class1] = X1
reconstructed_X[index_class2] = X2
reconstructed_X[index_class3] = X3