Refactor parsing of Patsy formulas
This commit is contained in:
parent
b2aaaabb0e
commit
0391877296
@ -16,7 +16,6 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import patsy
|
||||
from scipy import stats
|
||||
import statsmodels
|
||||
import statsmodels.api as sm
|
||||
@ -28,7 +27,7 @@ import itertools
|
||||
|
||||
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
||||
from .sig_tests import FTestResult
|
||||
from .utils import Estimate, check_nan, fmt_p
|
||||
from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category
|
||||
|
||||
def vif(df, formula=None, nan_policy='warn'):
|
||||
"""
|
||||
@ -59,22 +58,6 @@ def vif(df, formula=None, nan_policy='warn'):
|
||||
|
||||
return pd.Series(vifs)
|
||||
|
||||
def cols_for_formula(formula):
|
||||
"""Return the columns corresponding to the Patsy formula"""
|
||||
|
||||
model_desc = patsy.ModelDesc.from_formula(formula)
|
||||
cols = set()
|
||||
for term in model_desc.rhs_termlist:
|
||||
for factor in term.factors:
|
||||
name = factor.name()
|
||||
if '(' in name:
|
||||
# FIXME: Is there a better way of doing this?
|
||||
name = name[name.index('(')+1:name.index(')')]
|
||||
|
||||
cols.add(name)
|
||||
|
||||
return list(cols)
|
||||
|
||||
# ----------
|
||||
# Regression
|
||||
|
||||
@ -408,15 +391,13 @@ def regress(
|
||||
term = raw_name[:raw_name.index('[T.')]
|
||||
category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')]
|
||||
|
||||
patsy_factor = term
|
||||
if term.startswith('C('):
|
||||
term = term[2:-1]
|
||||
|
||||
# Add a new categorical term if not exists
|
||||
if term not in terms:
|
||||
# Try to guess the ref_category
|
||||
# FIXME: This is a VERY brittle implementation!!
|
||||
ref_category = sorted(df[term].unique())[0]
|
||||
|
||||
ref_category = formula_factor_ref_category(formula, df, patsy_factor)
|
||||
terms[term] = CategoricalTerm({}, ref_category)
|
||||
|
||||
terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name])
|
||||
|
46
yli/utils.py
46
yli/utils.py
@ -16,9 +16,13 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import patsy
|
||||
|
||||
import warnings
|
||||
|
||||
# ----------------------------
|
||||
# Data cleaning and validation
|
||||
|
||||
def check_nan(df, nan_policy):
|
||||
"""Check df against nan_policy and return cleaned input"""
|
||||
|
||||
@ -53,6 +57,9 @@ def as_2groups(df, data, group):
|
||||
|
||||
return group1, data1, group2, data2
|
||||
|
||||
# ----------
|
||||
# Formatting
|
||||
|
||||
def do_fmt_p(p):
|
||||
"""Return sign and formatted p value"""
|
||||
|
||||
@ -91,6 +98,9 @@ def fmt_p(p, *, html, nospace=False):
|
||||
|
||||
return pfmt
|
||||
|
||||
# ------------------------------
|
||||
# General result-related classes
|
||||
|
||||
class Estimate:
|
||||
"""A point estimate and surrounding confidence interval"""
|
||||
|
||||
@ -116,3 +126,39 @@ class Estimate:
|
||||
|
||||
def exp(self):
|
||||
return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
|
||||
|
||||
# --------------------------
|
||||
# Patsy formula manipulation
|
||||
|
||||
def cols_for_formula(formula):
|
||||
"""Return the columns corresponding to the Patsy formula"""
|
||||
|
||||
# Parse the formula
|
||||
model_desc = patsy.ModelDesc.from_formula(formula)
|
||||
|
||||
# Get the columns
|
||||
cols = set()
|
||||
for term in model_desc.rhs_termlist:
|
||||
for factor in term.factors:
|
||||
name = factor.name()
|
||||
if '(' in name:
|
||||
# FIXME: Is there a better way of doing this?
|
||||
# FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y))
|
||||
name = name[name.index('(')+1:name.index(')')]
|
||||
|
||||
cols.add(name)
|
||||
|
||||
return list(cols)
|
||||
|
||||
def formula_factor_ref_category(formula, df, factor):
|
||||
"""Get the reference category for a term in a Patsy formula referring to a categorical factor"""
|
||||
|
||||
# Parse the formula
|
||||
design_info = patsy.dmatrix(formula, df).design_info
|
||||
|
||||
# Get the corresponding factor_info
|
||||
factor_info = next(v for k, v in design_info.factor_infos.items() if k.name() == factor)
|
||||
|
||||
# FIXME: This does not handle complex expressions, e.g. C(x, Treatment(y))
|
||||
categories = factor_info.categories
|
||||
return categories[0]
|
||||
|
Loading…
Reference in New Issue
Block a user