From 039187729694eaa0e7fa7cd9e3446834528aef1f Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 15 Oct 2022 00:54:46 +1100 Subject: [PATCH] Refactor parsing of Patsy formulas --- yli/regress.py | 25 +++---------------------- yli/utils.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/yli/regress.py b/yli/regress.py index 34aa729..777e83d 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -16,7 +16,6 @@ import numpy as np import pandas as pd -import patsy from scipy import stats import statsmodels import statsmodels.api as sm @@ -28,7 +27,7 @@ import itertools from .bayes_factors import BayesFactor, bayesfactor_afbf from .sig_tests import FTestResult -from .utils import Estimate, check_nan, fmt_p +from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category def vif(df, formula=None, nan_policy='warn'): """ @@ -59,22 +58,6 @@ def vif(df, formula=None, nan_policy='warn'): return pd.Series(vifs) -def cols_for_formula(formula): - """Return the columns corresponding to the Patsy formula""" - - model_desc = patsy.ModelDesc.from_formula(formula) - cols = set() - for term in model_desc.rhs_termlist: - for factor in term.factors: - name = factor.name() - if '(' in name: - # FIXME: Is there a better way of doing this? - name = name[name.index('(')+1:name.index(')')] - - cols.add(name) - - return list(cols) - # ---------- # Regression @@ -408,15 +391,13 @@ def regress( term = raw_name[:raw_name.index('[T.')] category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')] + patsy_factor = term if term.startswith('C('): term = term[2:-1] # Add a new categorical term if not exists if term not in terms: - # Try to guess the ref_category - # FIXME: This is a VERY brittle implementation!! - ref_category = sorted(df[term].unique())[0] - + ref_category = formula_factor_ref_category(formula, df, patsy_factor) terms[term] = CategoricalTerm({}, ref_category) terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name]) diff --git a/yli/utils.py b/yli/utils.py index 64bc842..1d44e7d 100644 --- a/yli/utils.py +++ b/yli/utils.py @@ -16,9 +16,13 @@ import numpy as np import pandas as pd +import patsy import warnings +# ---------------------------- +# Data cleaning and validation + def check_nan(df, nan_policy): """Check df against nan_policy and return cleaned input""" @@ -53,6 +57,9 @@ def as_2groups(df, data, group): return group1, data1, group2, data2 +# ---------- +# Formatting + def do_fmt_p(p): """Return sign and formatted p value""" @@ -91,6 +98,9 @@ def fmt_p(p, *, html, nospace=False): return pfmt +# ------------------------------ +# General result-related classes + class Estimate: """A point estimate and surrounding confidence interval""" @@ -116,3 +126,39 @@ class Estimate: def exp(self): return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper)) + +# -------------------------- +# Patsy formula manipulation + +def cols_for_formula(formula): + """Return the columns corresponding to the Patsy formula""" + + # Parse the formula + model_desc = patsy.ModelDesc.from_formula(formula) + + # Get the columns + cols = set() + for term in model_desc.rhs_termlist: + for factor in term.factors: + name = factor.name() + if '(' in name: + # FIXME: Is there a better way of doing this? + # FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y)) + name = name[name.index('(')+1:name.index(')')] + + cols.add(name) + + return list(cols) + +def formula_factor_ref_category(formula, df, factor): + """Get the reference category for a term in a Patsy formula referring to a categorical factor""" + + # Parse the formula + design_info = patsy.dmatrix(formula, df).design_info + + # Get the corresponding factor_info + factor_info = next(v for k, v in design_info.factor_infos.items() if k.name() == factor) + + # FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y)) + categories = factor_info.categories + return categories[0]