我正在使用pandas在python中读取文件,然后将其保存在numpy数组中。 该文件的尺寸为11303402行×10列。 我需要拆分数据以进行交叉验证,为此我将数据切成11303402行x 9列示例和1个11303402行x 1列标签。 以下是代码:
tdata=pd.read_csv('train.csv')
tdata.columns='Arrival_Time','Creation_Time','x','y','z','User','Model','Device','sensor','gt']
User_Data = np.array(tdata)
features = User_Data[:,0:9]
labels = User_Data[:,9:10]
错误来自以下代码:
classes=np.unique(labels)
idx=labels==classes[0]
Yt=labels[idx]
Xt=features[idx,:]
在线:
Xt=features[idx,:]
它表示'数组'的索引太多了
所有3个数据集的形状为:
print np.shape(tdata) = (11303402, 10)
print np.shape(features) = (11303402, 9)
print np.shape(labels) = (11303402, 1)
如果有人知道这个问题,请帮助。
答案 0 :(得分:5)
问题是idx
具有形状(11303402,1)
,因为逻辑比较返回与labels
形状相同的数组。这两个维度使用features
中的所有索引。快速解决方法是
Xt=features[idx[:,0],:]