我正在尝试构建决策树分类器,我有以下代码:
def dtree(data, attrs, target):
data = data[:]
vals = []
for entry in data:
entry_index = attrs.index(target)
vals.append(entry[entry_index])
major = majority(data, attrs, target)
if not data or (len(attrs) - 1) <= 0:
return major
elif vals.count(vals[0]) == len(vals):
return vals[0]
else:
pick = choose(data, attrs, target)
tree = {pick:{}}
for each in get_vals(data, attrs, pick):
new_d = get_data(data, attrs, pick, each)
newAttr = attrs[:]
newAttr.remove(pick)
subtree = dtree(new_d, newAttr, target)
tree[pick][each] = subtree
return tree
其中:
data
是我的培训数据pandas
的{{1}}数据框,(33582 x 21)
是数据框标题的列表,attrs
是目标属性的字符串名称。target
是一个列表当我调用此方法时,我收到以下错误:
vals
我不确定这条线是什么引发了错误而且我不知道我应该做些什么来诊断它。
答案 0 :(得分:1)
因此,代码的那部分出现错误:
for entry in data:
entry_index = attrs.index(target)
vals.append(entry[entry_index])
我想,你想在这里做的是迭代data
DataFrame的所有行,并从每一行添加列target
的值到列表vals
。问题正在发生,因为迭代数据会返回列名(字符串),而不是行。因此,当您索引entry
字符串,索引为target
列时,您会获得IndexError
。
在pandas中,有更好的方法可以将列的所有值都列出来:
data[target].tolist()