在keras source code上,当准备来自sklearn的数据时,我们读到:
if len(y.shape) == 2 and y.shape[1] > 1:
self.classes_ = np.arange(y.shape[1])
elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
self.classes_ = np.unique(y)
y = np.searchsorted(self.classes_, y)
else:
raise ValueError('Invalid shape for y: ' + str(y.shape))
第一个if
用于多类分类,第一个elif
用于二进制分类。我不明白为什么是这条线
y = np.searchsorted(self.classes_, y)
需要的。不是lambda x: np.searchsorted(np.unique(x), x)
身份功能吗?
答案 0 :(得分:2)
Isn< lambda x:np.searchsorted(np.unique(x),x)身份函数?
仅当y
仅包含0
和1
时。调用这些函数可确保最终y
仅包含0
和1
,与用于表示二进制类的存在或不存在的符号无关;例如,某些输入可能会使用-1
和1
来代替,或其他内容。
我不认为,正如你所说,条件的第一个分支是多类问题,第二个分支是二元问题。我认为第二个分支也可以用于多类问题,其中类表示为数字,而不是单热编码。在这种情况下,再次,这种预处理将允许您为类使用任意符号(例如任意非顺序正和负整数)和" translate"它们进入范围[0, num_classes - 1]
。