我一直收到以下错误:AttributeError:'tuple'对象没有属性'size'。还添加了split功能
import numpy as np
def split(array):
N = {}
uniqe_array = np.unique(array)
for i in uniqe_array:
N[i] = np.where(array==i)
return N
def information_gain(x_array, y_array):
parent_entropy = entropy(x_array)
split_dict = split(y_array)
for val in split_dict.values():
freq = val.size / x_array.size
child_entropy = entropy([x_array[i] for i in val])
parent_entropy -= child_entropy* freq
return parent_entropy
x = np.array([0, 1, 0, 1, 0, 1])
y = np.array([0, 1, 0, 1, 1, 1])
print(round(information_gain(x, y), 4))
x = np.array([0, 0, 1, 1, 2, 2])
y = np.array([0, 1, 0, 1, 1, 1])
print(round(information_gain(x, y), 4))
答案 0 :(得分:1)
看来split_dict
的值是元组,而不是我认为应该是np.array
的值。我建议您看一下split
返回哪个功能,因为它可能正在创建元组而不是split_dict
。
根据函数np.array
内部的内容,它将split
返回到{0: (array([0, 2], dtype=int64),), 1: (array([1, 3, 4, 5], dtype=int64),)}
,因此值是包含split_dict
和数据类型(在这种情况下为{{ 1}})作为元素,因此提高了numpy.array
。
经过修改的int64
可以满足您的需求,
AttributeError
有关更多信息,请参见以下答案:What is the purpose of numpy.where returning a tuple?
答案 1 :(得分:0)
我认为函数的名称是len而不是size
objects = (1,2,3)
len(objects)
将输出设为3