我不太了解索引操作,您能解释一下这行代码吗?
train = data[ranks>=test_points]
在此功能中
def random_split(data,test_points):
ranks = np.arange(data.shape[0])
np.random.shuffle(ranks)
train = data[ranks>=test_points]
return train
所以我需要这样分割数据:一半用于训练,四分之一用于验证,四分之一用于测试。所以我这样做是这样的:
def random_split(data,test_points):
ranks = np.arange(data.shape[0])
np.random.shuffle(ranks)
train = data[ranks>=test_points]
other = data[ranks<test_points]
test = other[ranks>=int(test_points/4)]
valid = other[ranks<int(test_points/4)]
return train,test,valid
它不起作用,怎么了?您能帮我理解这段代码吗?
答案 0 :(得分:2)
问题在于,在other = data[ranks<test_points]
之后,变量other
和rank
的大小不再相同,因此会出现错误。您可以使用类似的
train_size = 500
validation_size = 100
train_set = data[:train_size]
validation_set = data[train_size: train_size + validation_size]
test_set = data[train_size + validation_size:]
注意:x[ i < 10]
样式索引特定于numpy。在一般的python中是不允许的。 <
被重载以返回布尔数组,例如
i = np.array([1, 3, 5, 4])
i <= 4 # return [True True False True]
在numpy中称为逻辑索引。