Implement vif
This commit is contained in:
parent
987520e700
commit
51008296c2
51
tests/test_vif.py
Normal file
51
tests/test_vif.py
Normal file
@ -0,0 +1,51 @@
|
||||
# scipy-yli: Helpful SciPy utilities and recipes
|
||||
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
from pytest import approx
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import yli
|
||||
|
||||
def test_vif_ol13_5():
|
||||
"""Compare yli.vif for Ott & Longnecker (2016) chapter 13.5"""
|
||||
|
||||
df = pd.DataFrame({
|
||||
'C': [460.05, 452.99, 443.22, 652.32, 642.23, 345.39, 272.37, 317.21, 457.12, 690.19, 350.63, 402.59, 412.18, 495.58, 394.36, 423.32, 712.27, 289.66, 881.24, 490.88, 567.79, 665.99, 621.45, 608.8, 473.64, 697.14, 207.51, 288.48, 284.88, 280.36, 217.38, 270.71],
|
||||
'D': [68.58, 67.33, 67.33, 68, 68, 67.92, 68.17, 68.42, 68.42, 68.33, 68.58, 68.75, 68.42, 68.92, 68.92, 68.42, 69.5, 68.42, 69.17, 68.92, 68.75, 70.92, 69.67, 70.08, 70.42, 71.08, 67.25, 67.17, 67.83, 67.83, 67.25, 67.83],
|
||||
'T1': [14, 10, 10, 11, 11, 13, 12, 14, 15, 12, 12, 13, 15, 17, 13, 11, 18, 15, 15, 16, 11, 22, 16, 19, 19, 20, 13, 9, 12, 12, 13, 7],
|
||||
'T2': [46, 73, 85, 67, 78, 51, 50, 59, 55, 71, 64, 47, 62, 52, 65, 67, 60, 76, 67, 59, 70, 57, 59, 58, 44, 57, 63, 48, 63, 71, 72, 80],
|
||||
'S': [687, 1065, 1065, 1065, 1065, 514, 822, 457, 822, 792, 560, 790, 530, 1050, 850, 778, 845, 530, 1090, 1050, 913, 828, 786, 821, 538, 1130, 745, 821, 886, 886, 745, 886],
|
||||
'PR': [0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
'NE': [1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'CT': [0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0],
|
||||
'BW': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
|
||||
'N': [14, 1, 1, 12, 12, 3, 5, 1, 5, 2, 3, 6, 2, 7, 16, 3, 17, 2, 1, 8, 15, 20, 18, 3, 19, 21, 8, 7, 11, 11, 8, 11],
|
||||
'PT': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
|
||||
})
|
||||
|
||||
vifs = yli.vif(df[['D', 'T1', 'T2', 'S', 'PR', 'NE', 'CT', 'BW', 'N', 'PT']])
|
||||
|
||||
assert vifs['D'] == approx(8.31830, abs=0.00001)
|
||||
assert vifs['T1'] == approx(6.08159, abs=0.00001)
|
||||
assert vifs['T2'] == approx(2.45712, abs=0.00001)
|
||||
assert vifs['S'] == approx(1.26727, abs=0.00001)
|
||||
assert vifs['PR'] == approx(1.66568, abs=0.00001)
|
||||
assert vifs['NE'] == approx(1.30924, abs=0.00001)
|
||||
assert vifs['CT'] == approx(1.32422, abs=0.00001)
|
||||
assert vifs['BW'] == approx(1.91292, abs=0.00001)
|
||||
assert vifs['N'] == approx(2.64429, abs=0.00001)
|
||||
assert vifs['PT'] == approx(2.88092, abs=0.00001)
|
@ -16,7 +16,7 @@
|
||||
|
||||
from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist
|
||||
from .fs import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
|
||||
from .regress import regress
|
||||
from .regress import regress, vif
|
||||
from .sig_tests import chi2, mannwhitney, ttest_ind
|
||||
|
||||
def reload_me():
|
||||
|
@ -15,16 +15,39 @@
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import patsy
|
||||
from scipy import stats
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.iolib.table import SimpleTable
|
||||
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
||||
|
||||
from datetime import datetime
|
||||
import itertools
|
||||
|
||||
from .utils import Estimate, check_nan, fmt_p_html, fmt_p_text
|
||||
|
||||
def vif(df, nan_policy='warn'):
|
||||
"""Calculate the variance inflation factor for each variable in df"""
|
||||
|
||||
# Check for/clean NaNs
|
||||
df = check_nan(df, nan_policy)
|
||||
|
||||
# Convert all to float64 otherwise statsmodels chokes with "ufunc 'isfinite' not supported for the input types ..."
|
||||
df = pd.get_dummies(df, drop_first=True) # Convert categorical dtypes
|
||||
df = df.astype('float64') # Convert all other dtypes
|
||||
|
||||
# Add intercept column
|
||||
orig_columns = list(df.columns)
|
||||
df['Intercept'] = [1] * len(df)
|
||||
|
||||
vifs = {}
|
||||
|
||||
for i, col in enumerate(orig_columns):
|
||||
vifs[col] = variance_inflation_factor(df, i)
|
||||
|
||||
return pd.Series(vifs)
|
||||
|
||||
def cols_for_formula(formula):
|
||||
"""Return the columns corresponding to the Patsy formula"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user