RandomForest IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

Time: 2017-06-21 14:37:57

Tags: python numpy scikit-learn classification random-forest

I am using RandomForestClassifier from sklearn:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

class RandomForest(RandomForestClassifier):

    def fit(self, x, y):
        self.unique_train_y,  y_classes = transform_y_vectors_in_classes(y)
        return RandomForestClassifier.fit(self, x, y_classes)

    def predict(self, x):
        y_classes = RandomForestClassifier.predict(self, x)
        predictions = transform_classes_in_y_vectors(y_classes, self.unique_train_y)
        return predictions

    def transform_classes_in_y_vectors(y_classes, unique_train_y):
        cyr = [unique_train_y[predicted_index] for predicted_index in y_classes]
        predictions = np.array(float(cyr))
        return predictions

I get this error message:

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
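The message comes from NumPy: it is raised when an array is indexed with a single value that is not an integer, such as a Python or NumPy float. A minimal reproduction, using a made-up `labels` array as a stand-in for `unique_train_y`:

```python
import numpy as np

# Hypothetical label array standing in for unique_train_y
labels = np.array(["cat", "dog", "bird"])

print(labels[np.int64(1)])  # an integer index works: prints dog

try:
    # A float scalar is rejected even though 1.0 is a whole number
    labels[np.float64(1.0)]
except IndexError as exc:
    print("IndexError:", exc)
```

Indexing with a whole float array also fails, but with a differently worded IndexError ("arrays used as indices must be of integer (or boolean) type"). The wording in the question therefore points to one float being used at a time, which matches the element-by-element list comprehension.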

1 answer:

Answer 0 (score: 1)

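It seems that y_classes contains values that are not valid indices.

When you use predicted_index to index unique_train_y, you get the exception because predicted_index is not what you think it is. RandomForestClassifier.predict returns the class labels themselves, taken from the values seen during fit; if those labels are stored as floats (for example 2.0), each predicted_index is a float, and NumPy rejects float indices even when they hold whole numbers.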
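Before doing the lookup, it can help to check what predict actually returned and convert it explicitly. The sketch below uses a hypothetical helper, `as_class_indices` (not part of scikit-learn or NumPy), which passes integer labels through and converts float labels only when every value is a finite whole number, so a genuinely bad prediction fails loudly instead of being silently truncated:

```python
import numpy as np

def as_class_indices(y_classes):
    """Hypothetical helper: turn predicted labels into integer indices.

    Integer arrays pass through unchanged. Float arrays are accepted only if
    every value is a finite whole number (e.g. 2.0); anything else raises,
    because it cannot be a position in unique_train_y.
    """
    y = np.asarray(y_classes)
    if np.issubdtype(y.dtype, np.integer):
        return y
    if not np.issubdtype(y.dtype, np.floating):
        raise TypeError(f"labels of dtype {y.dtype} cannot be used as indices")
    if not np.all(np.isfinite(y) & (y == np.round(y))):
        raise ValueError("some predicted labels are not whole numbers")
    return y.astype(np.intp)

print(as_class_indices(np.array([2.0, 0.0, 1.0])))  # prints [2 0 1]
```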

Try the following code:

cyr = [unique_train_y[int(predicted_index)] for predicted_index in y_classes]
# assuming unique_train_y is a list or array and each predicted_index is a whole number such as 2.0
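The lookup can also be done for all predictions at once by casting the labels to an integer array and using NumPy fancy indexing. The question does not show the bodies of `transform_y_vectors_in_classes` or `transform_classes_in_y_vectors`, so the versions below are hypothetical: they assume each row of `y` is one target vector and use `np.unique(..., axis=0, return_inverse=True)` to assign class ids.

```python
import numpy as np

def transform_y_vectors_in_classes(y):
    """Hypothetical version: map each distinct row of y to a class id 0..k-1."""
    unique_train_y, inverse = np.unique(np.asarray(y), axis=0, return_inverse=True)
    # reshape(-1) keeps the ids 1-D across NumPy versions
    return unique_train_y, inverse.reshape(-1)

def transform_classes_in_y_vectors(y_classes, unique_train_y):
    """Hypothetical version: map predicted class ids back to the original rows."""
    # Cast first: labels such as 2.0 are not valid NumPy indices.
    # astype truncates, so this assumes the labels are whole numbers.
    idx = np.asarray(y_classes).astype(np.intp)
    return np.asarray(unique_train_y)[idx]

y = np.array([[0, 1], [1, 0], [0, 1], [1, 1]])
unique_train_y, y_classes = transform_y_vectors_in_classes(y)
predicted = y_classes.astype(float)  # simulate float labels coming back from predict
print(transform_classes_in_y_vectors(predicted, unique_train_y))
```

Two smaller problems in the question's code are worth noting as well. `float(cyr)` raises TypeError because `float()` does not accept a list; `np.array(cyr, dtype=float)` is probably what was meant. And `transform_classes_in_y_vectors` is defined inside the class without `self` but called from `predict` as a plain function, which raises NameError unless a module-level function of the same name also exists; defining the helpers at module level, as in the sketch, avoids that.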