Refactor parsing of Patsy formulas

This commit is contained in:
RunasSudo 2022-10-15 00:54:46 +11:00
parent b2aaaabb0e
commit 0391877296
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
2 changed files with 49 additions and 22 deletions

View File

@ -16,7 +16,6 @@
import numpy as np
import pandas as pd
import patsy
from scipy import stats
import statsmodels
import statsmodels.api as sm
@ -28,7 +27,7 @@ import itertools
from .bayes_factors import BayesFactor, bayesfactor_afbf
from .sig_tests import FTestResult
from .utils import Estimate, check_nan, fmt_p
from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category
def vif(df, formula=None, nan_policy='warn'):
"""
@ -59,22 +58,6 @@ def vif(df, formula=None, nan_policy='warn'):
return pd.Series(vifs)
def cols_for_formula(formula):
    """
    Return the columns corresponding to the Patsy formula

    Only the right-hand side terms of the formula are inspected.
    A factor wrapped in function calls, e.g. C(x), C(x, Treatment(y)) or C(np.log(x)),
    is reduced to the expression in its innermost first positional argument.
    Q("name") is reduced to the quoted name itself.
    Compound expressions such as I(x + y) are returned as their source text ("x + y").

    Columns are returned as a list without duplicates, in order of first appearance.
    """

    # Imported locally so that this function does not rely on the module's import header
    import ast

    def column_for_factor(code):
        """Reduce the code of a factor containing a call to the column it wraps"""

        try:
            node = ast.parse(code, mode='eval').body
        except SyntaxError:
            # Not a Python expression: fall back to the text between the first pair of parentheses
            return code[code.index('(')+1:code.index(')')]

        # Follow the first positional argument through any nested calls;
        # further arguments (e.g. Treatment(...) contrasts) are not data columns
        while isinstance(node, ast.Call) and node.args:
            func, node = node.func, node.args[0]
            if (
                isinstance(func, ast.Name) and func.id == 'Q'
                and isinstance(node, ast.Constant) and isinstance(node.value, str)
            ):
                # Q("name") quotes a column name that is not a valid identifier
                return node.value

        if isinstance(node, ast.Name):
            return node.id
        return ast.get_source_segment(code, node) or code

    model_desc = patsy.ModelDesc.from_formula(formula)

    # A dict is used as an insertion-ordered set so the result order is deterministic
    cols = {}
    for term in model_desc.rhs_termlist:
        for factor in term.factors:
            name = factor.name()
            if '(' in name:
                name = column_for_factor(name)
            cols[name] = None

    return list(cols)
# ----------
# Regression
@ -408,15 +391,13 @@ def regress(
term = raw_name[:raw_name.index('[T.')]
category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')]
patsy_factor = term
if term.startswith('C('):
term = term[2:-1]
# Add a new categorical term if not exists
if term not in terms:
# Try to guess the ref_category
# FIXME: This is a VERY brittle implementation!!
ref_category = sorted(df[term].unique())[0]
ref_category = formula_factor_ref_category(formula, df, patsy_factor)
terms[term] = CategoricalTerm({}, ref_category)
terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name])

View File

@ -16,9 +16,13 @@
import numpy as np
import pandas as pd
import patsy
import warnings
# ----------------------------
# Data cleaning and validation
def check_nan(df, nan_policy):
"""Check df against nan_policy and return cleaned input"""
@ -53,6 +57,9 @@ def as_2groups(df, data, group):
return group1, data1, group2, data2
# ----------
# Formatting
def do_fmt_p(p):
"""Return sign and formatted p value"""
@ -91,6 +98,9 @@ def fmt_p(p, *, html, nospace=False):
return pfmt
# ------------------------------
# General result-related classes
class Estimate:
"""A point estimate and surrounding confidence interval"""
@ -116,3 +126,39 @@ class Estimate:
def exp(self):
return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
# --------------------------
# Patsy formula manipulation
def cols_for_formula(formula):
    """
    Return the columns corresponding to the Patsy formula

    Only the right-hand side terms of the formula are inspected.
    A factor wrapped in function calls, e.g. C(x), C(x, Treatment(y)) or C(np.log(x)),
    is reduced to the expression in its innermost first positional argument.
    Q("name") is reduced to the quoted name itself.
    Compound expressions such as I(x + y) are returned as their source text ("x + y").

    Columns are returned as a list without duplicates, in order of first appearance.
    """

    # Imported locally so that this function does not rely on the module's import header
    import ast

    def column_for_factor(code):
        """Reduce the code of a factor containing a call to the column it wraps"""

        try:
            node = ast.parse(code, mode='eval').body
        except SyntaxError:
            # Not a Python expression: fall back to the text between the first pair of parentheses
            return code[code.index('(')+1:code.index(')')]

        # Follow the first positional argument through any nested calls;
        # further arguments (e.g. Treatment(...) contrasts) are not data columns
        while isinstance(node, ast.Call) and node.args:
            func, node = node.func, node.args[0]
            if (
                isinstance(func, ast.Name) and func.id == 'Q'
                and isinstance(node, ast.Constant) and isinstance(node.value, str)
            ):
                # Q("name") quotes a column name that is not a valid identifier
                return node.value

        if isinstance(node, ast.Name):
            return node.id
        return ast.get_source_segment(code, node) or code

    # Parse the formula
    model_desc = patsy.ModelDesc.from_formula(formula)

    # Get the columns
    # A dict is used as an insertion-ordered set so the result order is deterministic
    cols = {}
    for term in model_desc.rhs_termlist:
        for factor in term.factors:
            name = factor.name()
            if '(' in name:
                name = column_for_factor(name)
            cols[name] = None

    return list(cols)
def formula_factor_ref_category(formula, df, factor):
    """
    Get the reference category for a term in a Patsy formula referring to a categorical factor

    formula: Patsy formula, evaluated against df with patsy.dmatrix
    df: Data to evaluate the formula against
    factor: Name of the factor exactly as Patsy reports it, e.g. "x" or "C(x, Treatment('b'))"

    Returns the category whose row in the factor's contrast matrix is entirely zero,
    which is the reference level under treatment coding (including an explicitly chosen one).
    If the coding has no unique all-zero row, the first category is returned.

    Raises ValueError if no factor in the formula has the given name.
    """

    # Build the design matrix to obtain Patsy's resolved coding of every factor
    design_info = patsy.dmatrix(formula, df).design_info

    # Get the corresponding factor and its factor_info
    factor_key = next((k for k in design_info.factor_infos if k.name() == factor), None)
    if factor_key is None:
        raise ValueError(f'Factor {factor!r} not found in formula {formula!r}')
    categories = design_info.factor_infos[factor_key].categories

    # Look for a contrast matrix of this factor with exactly one all-zero row:
    # that row is the level absorbed into the baseline, i.e. the reference category
    for subterms in design_info.term_codings.values():
        for subterm in subterms:
            contrast = subterm.contrast_matrices.get(factor_key)
            if contrast is None:
                continue

            zero_rows = [i for i, row in enumerate(contrast.matrix) if not any(row)]
            if len(zero_rows) == 1:
                return categories[zero_rows[0]]

    # No coding identifies a reference level, so fall back to the first category
    return categories[0]