I want to use RSS

Date: 2019-05-28 07:49:51

Tags: python tree regression decision-tree

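1.2 RSS As described in Chapter 8 of the textbook, we need a measure of how good a particular split/tree is. Here we choose the residual sum of squares (RSS). Knowing that the value to predict (Wage(k)) is stored in element[-1] of each element of a split, implement the RSS computation on a list of splits. You can check the result for a specific split using the code cell below the implementation.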
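
For reference, one common way to write this quantity is sketched below. The notation is an assumption made here for illustration and is not taken from the textbook: $S_1, \dots, S_m$ are the parts of a split, $y_i$ is the target stored in element[-1] of row $i$, and $\bar{y}_{S_j}$ is the mean target within part $S_j$.

```latex
\mathrm{RSS} = \sum_{j=1}^{m} \sum_{i \in S_j} \left( y_i - \bar{y}_{S_j} \right)^2,
\qquad
\bar{y}_{S_j} = \frac{1}{|S_j|} \sum_{i \in S_j} y_i
```

For the check cell in the code, the two parts have targets {2, 8} and {5}, so the means are 5 and 5 and the sum is $(2-5)^2 + (8-5)^2 + (5-5)^2 = 18$, which matches the expected value of 18.0.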

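1.3 Split We will write a split function that divides the data into two parts, given a feature index, a split value, and the data. Implement the split conditions, following the convention.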

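1.4 Optimal split creation There is no theoretical result that allows finding the best split without trying all possible splits, so we implement an RSS minimizer over all of the splits. Using the previously coded functions, fill in the #TODO parts. You can check the return value in the following cell.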

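1.5 Tree building and prediction Now, by putting all parts of the code together, we can build the entire tree recursively. Comment on the given code, especially the importance of the parameter min_size for the model structure. Using the same coding paradigm, we can use our model to perform regression on the test set, as you can see in the next code cell. Which part of the overall model is now missing? Explain its importance in a real machine learning problem. (Bonus) Implement it.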

I want to implement a tree-based regression model using RSS. I want to fill in the blanks below, but it is too difficult.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math

data = pd.read_csv("Wages.csv", sep=";")

training_set = np.array(data[:10])
test_set = np.array(data[10:])


-- RSS --

verbose = False
def RSS(splits):
    """
    Return the RSS of the input splits. The input should be in the form of a list of list
    """
    residual = 0
    for split in splits:
        if(len(split) != 0):
            mean = mean(split[:-1])
            if(verbose):print("Mean :" + str(mean))
            residual = ##TODO
    return residual

split_1 = np.array([[[0,2],[0,8]],[[4,5]]])
RSS_value = RSS(split_1)
if (type(RSS_value) not in [int,float,np.float16,np.float32,np.float64]):
    print("TypeError : check your output")
elif(RSS(split_1) == 18.0):
    print("Your calculations are right, at least on this specific example")
else:
    print("Your calculations are wrong")


-- Split --

def split(index, value, data):
    """
    Splits the input @data into two parts, based on the feature at @index position, using @value as a boundary value
    """
    left_split = #TODO condition
    right_split = #TODO condition
    return [left_split, right_split]


-- optimal split creation

def split_tester(data):
    """
    Find the best possible split possible for the current @data.
    Loops over all the possible features, and all values for the given features to test every possible split
    """
    optimal_split_ind, optimal_split_value, optimal_residual, optimal_splits = -1,-1,float("inf"),[] #Initialize such that the first split is better than initialization
    for curr_ind in range(data.shape[1]-1):
        for curr_val in data:
            if(verbose):print("Curr_split : " + str((curr_ind, curr_val[curr_ind])))
            split_res = #TODO (comments : get the current split)

            if(verbose):print(split_res)
            residual_value = #TODO (comments : get the RSS of the current split)

            if(verbose):print("Residual : " + str(residual_value))
            if residual_value < optimal_residual:
                optimal_split_ind, optimal_split_value, optimal_residual, optimal_splits = curr_ind,\
                                                                curr_val[curr_ind], residual_value, split_res

    return optimal_split_ind, optimal_split_value, optimal_splits



-- tree building --


def tree_building(data, min_size):
    """
    Recursively builds a tree using the split tester built before.
    """
    if(data.shape[0] > min_size):
        ind, value, [left, right] = split_tester(data)
        left, right = np.array(left), np.array(right)
        return [tree_building(left, min_size), tree_building(right, min_size),ind,value]
    else:
        return data


tree = tree_building(training_set,2)




def predict(tree, input_vector):
    if(type(tree[-1]) != np.int64):
        if(len(tree) == 1):
            return(tree[0][-1])
        else:
            return(np.mean([element[-1] for element in tree]))
    else:
        left_tree, right_tree, split_ind, split_value = tree
        if(input_vector[split_ind]<split_value):
            return predict(left_tree, input_vector)
        else:
            return predict(right_tree, input_vector)



for employee in test_set:
    print("Predicted : " + str(predict(tree,employee)) + ", Actual : " + str(employee[-1]))

I am working on the code to fill in the #TODO parts, but I don't know how. Please help me.

1 Answer:

Answer 0 (score: 0)

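If I understand correctly, you are only asking for the calculations marked #TODO in the posted code. If you compute the errors of the model's predictions, those error values are sometimes called "residual errors". You cannot simply add them up: some are negative and some are positive, so they may cancel each other out. However, if the errors are all squared, the squared errors are all positive and can be summed. This is where the term "residual sum of squares" (RSS) comes from. You can compute this value with something like "RSS = numpy.sum(numpy.square(errors))".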
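
To make that concrete, here is a minimal sketch of how the three #TODO spots might be filled in. It is one possible reading of the exercise, not the official solution, and it rests on some assumptions: the names rss, split and split_tester are local stand-ins for the functions in the question, the split convention (rows with feature < value go left, all others go right) is chosen to match the `<` comparison in the posted predict function, and the toy array is made-up data rather than Wages.csv. Two things differ from the posted code on purpose. The mean is taken over the targets (row[-1] of every row) instead of `mean(split[:-1])`, which would also overwrite the name `mean`. The check example is passed as plain nested lists, because recent NumPy versions may refuse to build an array from ragged nested lists.

```python
import numpy as np


def rss(splits):
    """Sum, over all parts, of squared deviations of the targets from the part mean."""
    residual = 0.0
    for part in splits:
        if len(part) != 0:
            # The target (value to predict) is the last entry of each row.
            targets = np.array([row[-1] for row in part], dtype=float)
            errors = targets - targets.mean()
            residual += float(np.sum(np.square(errors)))
    return residual


def split(index, value, data):
    """Assumed convention: feature < value goes left, everything else goes right."""
    left_split = [row for row in data if row[index] < value]
    right_split = [row for row in data if row[index] >= value]
    return [left_split, right_split]


def split_tester(data):
    """Try every (feature, observed value) pair and keep the split with the lowest RSS."""
    best_ind, best_val, best_rss, best_splits = -1, -1, float("inf"), []
    for curr_ind in range(data.shape[1] - 1):  # last column is the target
        for curr_row in data:
            split_res = split(curr_ind, curr_row[curr_ind], data)
            residual_value = rss(split_res)
            if residual_value < best_rss:
                best_ind, best_val = curr_ind, curr_row[curr_ind]
                best_rss, best_splits = residual_value, split_res
    return best_ind, best_val, best_splits


# Check example from the question: targets {2, 8} on the left, {5} on the right.
example = [[[0, 2], [0, 8]], [[4, 5]]]
print(rss(example))  # 18.0

# Made-up data: one feature column followed by the target column.
toy = np.array([[1.0, 10.0], [2.0, 12.0], [8.0, 30.0], [9.0, 32.0]])
ind, val, parts = split_tester(toy)
print(ind, val)  # 0 8.0
```

On the toy data the chosen boundary is 8.0, which separates the two low targets from the two high ones and leaves an RSS of 4. Returning a plain Python float from rss keeps the type check in the question's test cell satisfied.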