我对Python,Stats和使用DS库比较陌生,我的要求是对具有n列的数据集运行多重共线性测试,并确保列/变量具有VIF>完全放弃了5个。
我找到了一个代码,
from statsmodels.stats.outliers_influence import variance_inflation_factor
def calculate_vif_(X, thresh=5.0):
variables = range(X.shape[1])
tmp = range(X[variables].shape[1])
print(tmp)
dropped=True
while dropped:
dropped=False
vif = [variance_inflation_factor(X[variables].values, ix) for ix in range(X[variables].shape[1])]
maxloc = vif.index(max(vif))
if max(vif) > thresh:
print('dropping \'' + X[variables].columns[maxloc] + '\' at index: ' + str(maxloc))
del variables[maxloc]
dropped=True
print('Remaining variables:')
print(X.columns[variables])
return X[variables]
但是,我不清楚,我是否应该在X参数的位置完全传递数据集?如果是,则无效。
请帮忙!
答案 0 :(得分:1)
我也遇到类似的问题。我通过更改variables
的定义方式并找到另一种删除元素的方法来修复它。
以下脚本适用于Anaconda 5.0.1和Python 3.6(撰写本文时的最新版本)。
import numpy as np
import pandas as pd
import time
from statsmodels.stats.outliers_influence import variance_inflation_factor
from joblib import Parallel, delayed
# Defining the function that you will run later
def calculate_vif_(X, thresh=5.0):
variables = [X.columns[i] for i in range(X.shape[1])]
dropped=True
while dropped:
dropped=False
print(len(variables))
vif = Parallel(n_jobs=-1,verbose=5)(delayed(variance_inflation_factor)(X[variables].values, ix) for ix in range(len(variables)))
maxloc = vif.index(max(vif))
if max(vif) > thresh:
print(time.ctime() + ' dropping \'' + X[variables].columns[maxloc] + '\' at index: ' + str(maxloc))
variables.pop(maxloc)
dropped=True
print('Remaining variables:')
print([variables])
return X[[i for i in variables]]
X = df[feature_list] # Selecting your data
X2 = calculate_vif_(X,5) # Actually running the function
如果您有许多功能,则需要很长时间才能运行。所以我做了另一个更改,让它有并行工作,以防你有多个CPU可用。
享受!
答案 1 :(得分:1)
我调整了代码并通过以下代码设法实现了所需的结果,并进行了一些异常处理 -
def multicollinearity_check(X, thresh=5.0):
data_type = X.dtypes
# print(type(data_type))
int_cols = \
X.select_dtypes(include=['int', 'int16', 'int32', 'int64', 'float', 'float16', 'float32', 'float64']).shape[1]
total_cols = X.shape[1]
try:
if int_cols != total_cols:
raise Exception('All the columns should be integer or float, for multicollinearity test.')
else:
variables = list(range(X.shape[1]))
dropped = True
print('''\n\nThe VIF calculator will now iterate through the features and calculate their respective values.
It shall continue dropping the highest VIF features until all the features have VIF less than the threshold of 5.\n\n''')
while dropped:
dropped = False
vif = [variance_inflation_factor(X.iloc[:, variables].values, ix) for ix in variables]
print('\n\nvif is: ', vif)
maxloc = vif.index(max(vif))
if max(vif) > thresh:
print('dropping \'' + X.iloc[:, variables].columns[maxloc] + '\' at index: ' + str(maxloc))
# del variables[maxloc]
X.drop(X.columns[variables[maxloc]], 1, inplace=True)
variables = list(range(X.shape[1]))
dropped = True
print('\n\nRemaining variables:\n')
print(X.columns[variables])
# return X.iloc[:,variables]
return X
except Exception as e:
print('Error caught: ', e)
答案 2 :(得分:0)
首先,感谢 @DanSan 提出了多重共线性计算中的并行化。现在,对于形状为 (22500, 71) 的多维数据集,我的计算时间至少提高了 50%。但是我在我正在处理的数据集上遇到了一个有趣的挑战。数据集实际上包含一些分类列,我有 Binary encoded 使用 Category-encoders,因此某些列只有 1 个唯一值。对于此类列,VIF 的值是非有限的或 NaN !
以下快照显示了我的数据集中 71 个二进制编码列中某些列的 VIF 值:
在这些情况下,根据我的痛苦经验,使用 @Aakash Basu 和 @DanSan 的代码后剩余的列数有时可能会取决于数据集中列的顺序,因为列是根据最大 VIF 值线性删除的。而且只有一个值的列对于任何机器学习模型都有些愚蠢,因为它会强行将偏见强加到系统中!
为了处理这个问题,您可以使用以下更新的代码:
from joblib import Parallel, delayed
from statsmodels.stats.outliers_influence import variance_inflation_factor
def removeMultiColl(data, vif_threshold = 5.0):
for i in data.columns:
if data[i].nunique() == 1:
print(f"Dropping {i} due to just 1 unique value")
data.drop(columns = i, inplace = True)
drop = True
col_list = list(data.columns)
while drop == True:
drop = False
vif_list = Parallel(n_jobs = -1, verbose = 5)(delayed(variance_inflation_factor)(data[col_list].values, i) for i in range(data[col_list].shape[1]))
max_index = vif_list.index(max(vif_list))
if vif_list[max_index] > vif_threshold:
print(f"Dropping column : {col_list[max_index]} at index - {max_index}")
del col_list[max_index]
drop = True
print("Remaining columns :\n", list(data[col_list].columns))
return data[col_list]
祝你好运!