Sklearn: cloning a custom transformer replaces values in a dictionary set in __init__

Time: 2019-07-24 19:48:15

Tags: python dictionary scikit-learn copy cloning

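Say we have a custom transformer A that takes some parameters. These parameters are set in __init__ when A is instantiated.

When we clone A (as happens, for example, inside cross_val_score), the parameters are copied successfully.

However, if a parameter is only placed into a structure such as a dictionary, cloning replaces its value with None.

In cases where None does not raise an error, this is a silent bug: the cloned version of A runs, but produces results that differ from the original (which is how I ran into the problem in the first place).

Below is a fully reproducible example (sklearn version '0.20.3').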
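
One likely explanation is that clone does not copy attributes. It reads each constructor argument back from an attribute of the same name and then builds a new instance from those values. The sketch below imitates that mechanism in plain Python. The helpers get_init_params and simple_clone and the two toy classes are hypothetical stand-ins, not sklearn's actual code, and reporting a missing attribute as None is an assumption chosen to match the behaviour seen in 0.20.3.

```python
import inspect


def get_init_params(obj):
    """Collect constructor arguments by reading same-named attributes.

    Simplified stand-in for an estimator's get_params(); a missing
    attribute is reported as None (assumption, see the text above).
    """
    names = [
        name
        for name in inspect.signature(type(obj).__init__).parameters
        if name != "self"
    ]
    return {name: getattr(obj, name, None) for name in names}


def simple_clone(obj):
    """Build a fresh instance from the collected constructor arguments."""
    return type(obj)(**get_init_params(obj))


class KeepsOnlyDict:
    def __init__(self, n_cols_to_keep):
        # The argument is never stored under its own name.
        self.cols_to_keep_dict = {"n_cols": n_cols_to_keep}


class KeepsArgument:
    def __init__(self, n_cols_to_keep):
        # The argument is stored under its own name first.
        self.n_cols_to_keep = n_cols_to_keep
        self.cols_to_keep_dict = {"n_cols": n_cols_to_keep}


lost = simple_clone(KeepsOnlyDict(5))
kept = simple_clone(KeepsArgument(5))
print(lost.cols_to_keep_dict)  # {'n_cols': None}
print(kept.cols_to_keep_dict)  # {'n_cols': 5}
```

Under this reading, the dictionary itself is never replaced. It is simply rebuilt by __init__ from an argument that could not be recovered.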

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import clone


class MyTransformA(BaseEstimator, TransformerMixin):

    def __init__(self, n_cols_to_keep):
        self.cols_to_keep_dict = {'n_cols': n_cols_to_keep}  # <--- the input is only stored inside the dict

    def fit(self, X, *_):
        return self 

    def transform(self, X, *_):
        return X


class MyTransformB(BaseEstimator, TransformerMixin):

    def __init__(self, n_cols_to_keep):
        self.n_cols_to_keep = n_cols_to_keep  # <--- this time we save the input immediately 
        self.cols_to_keep_dict = {'n_cols': self.n_cols_to_keep}  

    def fit(self, X, *_):
        return self 

    def transform(self, X, *_):
        return X

my_transform_a = MyTransformA(n_cols_to_keep=5)
my_transform_a_clone = clone(my_transform_a)

my_transform_b = MyTransformB(n_cols_to_keep=5)
my_transform_b_clone = clone(my_transform_b)

print('Using MyTransformA:')
print('  my_transform_a.cols_to_keep_dict:        %s' % str(my_transform_a.cols_to_keep_dict))
print('  my_transform_a_clone.cols_to_keep_dict:  %s  <------ ?' % str(my_transform_a_clone.cols_to_keep_dict))

print('\nUsing MyTransformB:')
print('  my_transform_b.cols_to_keep_dict:        %s' % str(my_transform_b.cols_to_keep_dict))
print('  my_transform_b_clone.cols_to_keep_dict): %s' % str(my_transform_b_clone.cols_to_keep_dict))
Expected results
Using MyTransformA:
  my_transform_a.cols_to_keep_dict:        {'n_cols': 5}
  my_transform_a_clone.cols_to_keep_dict:  {'n_cols': 5}  <------ Does not happen

Using MyTransformB:
  my_transform_b.cols_to_keep_dict:        {'n_cols': 5}
  my_transform_b_clone.cols_to_keep_dict): {'n_cols': 5}
Actual results
Using MyTransformA:
  my_transform_a.cols_to_keep_dict:        {'n_cols': 5}
  my_transform_a_clone.cols_to_keep_dict:  {'n_cols': None}  <------ ?

Using MyTransformB:
  my_transform_b.cols_to_keep_dict:        {'n_cols': 5}
  my_transform_b_clone.cols_to_keep_dict): {'n_cols': 5}
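
A pattern that should avoid the mismatch, assuming the scikit-learn convention that __init__ only stores its arguments unchanged, is to keep the parameter under its own name and build the dictionary later in fit. MyTransformC below is a hypothetical variant written for illustration, and the trailing underscore on cols_to_keep_dict_ follows the usual naming for state computed during fit. Treat it as a sketch rather than the one correct design.

```python
from sklearn.base import BaseEstimator, TransformerMixin, clone


class MyTransformC(BaseEstimator, TransformerMixin):
    """Hypothetical variant of MyTransformA that survives clone()."""

    def __init__(self, n_cols_to_keep):
        # Store the argument unchanged, under the same name as the parameter.
        self.n_cols_to_keep = n_cols_to_keep

    def fit(self, X, y=None):
        # Derived state is built here instead of in __init__.
        self.cols_to_keep_dict_ = {'n_cols': self.n_cols_to_keep}
        return self

    def transform(self, X):
        return X


original = MyTransformC(n_cols_to_keep=5)
cloned = clone(original).fit([[0, 1], [2, 3]])
print(cloned.get_params())
print(cloned.cols_to_keep_dict_)
```

Because the clone is rebuilt from n_cols_to_keep alone, the dictionary is recreated with the right value whenever fit runs.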

0 answers