Python“数组索引太多”

时间:2016-05-11 12:41:51

标签: python numpy

我正在使用pandas在python中读取文件,然后将其保存在numpy数组中。 该文件的尺寸为11303402行×10列。 我需要拆分数据以进行交叉验证,为此我将数据切成11303402行x 9列示例和1个11303402行x 1列标签。 以下是代码:

tdata=pd.read_csv('train.csv')
tdata.columns='Arrival_Time','Creation_Time','x','y','z','User','Model','Device','sensor','gt']

User_Data = np.array(tdata)
features = User_Data[:,0:9]
labels = User_Data[:,9:10]

错误来自以下代码:

classes=np.unique(labels)
idx=labels==classes[0]
Yt=labels[idx]
Xt=features[idx,:]

在线:

Xt=features[idx,:]

它表示'数组'的索引太多了

所有3个数据集的形状为:

print np.shape(tdata) = (11303402, 10)
print np.shape(features) = (11303402, 9)
print np.shape(labels) = (11303402, 1)

如果有人知道这个问题,请帮助。

1 个答案:

答案 0 :(得分:5)

问题是idx具有形状(11303402,1),因为逻辑比较返回与labels形状相同的数组。这两个维度使用features中的所有索引。快速解决方法是

Xt=features[idx[:,0],:]