diff --git a/yli/regress.py b/yli/regress.py index 5708cc7..5ab7946 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -28,8 +28,16 @@ import itertools from .utils import Estimate, check_nan, fmt_p_html, fmt_p_text -def vif(df, nan_policy='warn'): - """Calculate the variance inflation factor for each variable in df""" +def vif(df, formula=None, nan_policy='warn'): + """ + Calculate the variance inflation factor for each variable in df + + formula: If specified, calculate the VIF only for the variables in the formula + """ + + if formula: + # Only consider columns in the formula + df = df[cols_for_formula(formula)] # Check for/clean NaNs df = check_nan(df, nan_policy)