Refactor parsing of Patsy formulas

This commit is contained in:
RunasSudo 2022-10-15 00:54:46 +11:00
parent b2aaaabb0e
commit 0391877296
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
2 changed files with 49 additions and 22 deletions

View File

@ -16,7 +16,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import patsy
from scipy import stats from scipy import stats
import statsmodels import statsmodels
import statsmodels.api as sm import statsmodels.api as sm
@ -28,7 +27,7 @@ import itertools
from .bayes_factors import BayesFactor, bayesfactor_afbf from .bayes_factors import BayesFactor, bayesfactor_afbf
from .sig_tests import FTestResult from .sig_tests import FTestResult
from .utils import Estimate, check_nan, fmt_p from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category
def vif(df, formula=None, nan_policy='warn'): def vif(df, formula=None, nan_policy='warn'):
""" """
@ -59,22 +58,6 @@ def vif(df, formula=None, nan_policy='warn'):
return pd.Series(vifs) return pd.Series(vifs)
def cols_for_formula(formula):
"""Return the columns corresponding to the Patsy formula"""
model_desc = patsy.ModelDesc.from_formula(formula)
cols = set()
for term in model_desc.rhs_termlist:
for factor in term.factors:
name = factor.name()
if '(' in name:
# FIXME: Is there a better way of doing this?
name = name[name.index('(')+1:name.index(')')]
cols.add(name)
return list(cols)
# ---------- # ----------
# Regression # Regression
@ -408,15 +391,13 @@ def regress(
term = raw_name[:raw_name.index('[T.')] term = raw_name[:raw_name.index('[T.')]
category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')] category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')]
patsy_factor = term
if term.startswith('C('): if term.startswith('C('):
term = term[2:-1] term = term[2:-1]
# Add a new categorical term if not exists # Add a new categorical term if not exists
if term not in terms: if term not in terms:
# Try to guess the ref_category ref_category = formula_factor_ref_category(formula, df, patsy_factor)
# FIXME: This is a VERY brittle implementation!!
ref_category = sorted(df[term].unique())[0]
terms[term] = CategoricalTerm({}, ref_category) terms[term] = CategoricalTerm({}, ref_category)
terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name]) terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name])

View File

@ -16,9 +16,13 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import patsy
import warnings import warnings
# ----------------------------
# Data cleaning and validation
def check_nan(df, nan_policy): def check_nan(df, nan_policy):
"""Check df against nan_policy and return cleaned input""" """Check df against nan_policy and return cleaned input"""
@ -53,6 +57,9 @@ def as_2groups(df, data, group):
return group1, data1, group2, data2 return group1, data1, group2, data2
# ----------
# Formatting
def do_fmt_p(p): def do_fmt_p(p):
"""Return sign and formatted p value""" """Return sign and formatted p value"""
@ -91,6 +98,9 @@ def fmt_p(p, *, html, nospace=False):
return pfmt return pfmt
# ------------------------------
# General result-related classes
class Estimate: class Estimate:
"""A point estimate and surrounding confidence interval""" """A point estimate and surrounding confidence interval"""
@ -116,3 +126,39 @@ class Estimate:
def exp(self): def exp(self):
return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper)) return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
# --------------------------
# Patsy formula manipulation
def cols_for_formula(formula):
"""Return the columns corresponding to the Patsy formula"""
# Parse the formula
model_desc = patsy.ModelDesc.from_formula(formula)
# Get the columns
cols = set()
for term in model_desc.rhs_termlist:
for factor in term.factors:
name = factor.name()
if '(' in name:
# FIXME: Is there a better way of doing this?
# FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y))
name = name[name.index('(')+1:name.index(')')]
cols.add(name)
return list(cols)
def formula_factor_ref_category(formula, df, factor):
"""Get the reference category for a term in a Patsy formula referring to a categorical factor"""
# Parse the formula
design_info = patsy.dmatrix(formula, df).design_info
# Get the corresponding factor_info
factor_info = next(v for k, v in design_info.factor_infos.items() if k.name() == factor)
# FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y))
categories = factor_info.categories
return categories[0]