Refactor parsing of Patsy formulas
This commit is contained in:
parent
b2aaaabb0e
commit
0391877296
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import patsy
|
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
import statsmodels
|
import statsmodels
|
||||||
import statsmodels.api as sm
|
import statsmodels.api as sm
|
||||||
@ -28,7 +27,7 @@ import itertools
|
|||||||
|
|
||||||
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
||||||
from .sig_tests import FTestResult
|
from .sig_tests import FTestResult
|
||||||
from .utils import Estimate, check_nan, fmt_p
|
from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category
|
||||||
|
|
||||||
def vif(df, formula=None, nan_policy='warn'):
|
def vif(df, formula=None, nan_policy='warn'):
|
||||||
"""
|
"""
|
||||||
@ -59,22 +58,6 @@ def vif(df, formula=None, nan_policy='warn'):
|
|||||||
|
|
||||||
return pd.Series(vifs)
|
return pd.Series(vifs)
|
||||||
|
|
||||||
def cols_for_formula(formula):
    """
    List the data columns referenced by the right-hand side of a Patsy formula.

    Every factor of every right-hand-side term is inspected. A factor whose
    name contains a call, e.g. ``C(x)`` or ``np.log(x)``, is reduced to the
    text between its first ``(`` and its first ``)``.

    The result is built from a set, so it has no duplicates and no
    guaranteed order.
    """

    description = patsy.ModelDesc.from_formula(formula)

    columns = set()
    for rhs_term in description.rhs_termlist:
        for rhs_factor in rhs_term.factors:
            label = rhs_factor.name()

            if '(' in label:
                # FIXME: Is there a better way of doing this?
                # Purely textual unwrapping: first '(' up to first ')'
                inner_start = label.index('(') + 1
                inner_end = label.index(')')
                label = label[inner_start:inner_end]

            columns.add(label)

    return list(columns)
|
|
||||||
|
|
||||||
# ----------
|
# ----------
|
||||||
# Regression
|
# Regression
|
||||||
|
|
||||||
@ -408,15 +391,13 @@ def regress(
|
|||||||
term = raw_name[:raw_name.index('[T.')]
|
term = raw_name[:raw_name.index('[T.')]
|
||||||
category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')]
|
category = raw_name[raw_name.index('[T.')+3:raw_name.index(']')]
|
||||||
|
|
||||||
|
patsy_factor = term
|
||||||
if term.startswith('C('):
|
if term.startswith('C('):
|
||||||
term = term[2:-1]
|
term = term[2:-1]
|
||||||
|
|
||||||
# Add a new categorical term if not exists
|
# Add a new categorical term if not exists
|
||||||
if term not in terms:
|
if term not in terms:
|
||||||
# Try to guess the ref_category
|
ref_category = formula_factor_ref_category(formula, df, patsy_factor)
|
||||||
# FIXME: This is a VERY brittle implementation!!
|
|
||||||
ref_category = sorted(df[term].unique())[0]
|
|
||||||
|
|
||||||
terms[term] = CategoricalTerm({}, ref_category)
|
terms[term] = CategoricalTerm({}, ref_category)
|
||||||
|
|
||||||
terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name])
|
terms[term].categories[category] = SingleTerm(raw_name, beta, result.pvalues[raw_name])
|
||||||
|
46
yli/utils.py
46
yli/utils.py
@ -16,9 +16,13 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import patsy
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data cleaning and validation
|
||||||
|
|
||||||
def check_nan(df, nan_policy):
|
def check_nan(df, nan_policy):
|
||||||
"""Check df against nan_policy and return cleaned input"""
|
"""Check df against nan_policy and return cleaned input"""
|
||||||
|
|
||||||
@ -53,6 +57,9 @@ def as_2groups(df, data, group):
|
|||||||
|
|
||||||
return group1, data1, group2, data2
|
return group1, data1, group2, data2
|
||||||
|
|
||||||
|
# ----------
|
||||||
|
# Formatting
|
||||||
|
|
||||||
def do_fmt_p(p):
|
def do_fmt_p(p):
|
||||||
"""Return sign and formatted p value"""
|
"""Return sign and formatted p value"""
|
||||||
|
|
||||||
@ -91,6 +98,9 @@ def fmt_p(p, *, html, nospace=False):
|
|||||||
|
|
||||||
return pfmt
|
return pfmt
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# General result-related classes
|
||||||
|
|
||||||
class Estimate:
|
class Estimate:
|
||||||
"""A point estimate and surrounding confidence interval"""
|
"""A point estimate and surrounding confidence interval"""
|
||||||
|
|
||||||
@ -116,3 +126,39 @@ class Estimate:
|
|||||||
|
|
||||||
def exp(self):
|
def exp(self):
|
||||||
return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
|
return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# Patsy formula manipulation
|
||||||
|
|
||||||
|
def _first_call_argument(expr):
    """
    Return the first top-level argument of the first call in *expr*.

    Scanning starts just after the first ``(`` and stops at the first comma
    or closing bracket that is not nested inside another bracket pair or a
    string literal, so ``C(x, Treatment('b'))`` yields ``x`` and
    ``Q("a,b")`` yields ``"a,b"``.

    Raises ValueError if *expr* contains no ``(``.
    """

    start = expr.index('(') + 1
    depth = 0
    quote = None
    pos = start

    while pos < len(expr):
        char = expr[pos]
        if quote is not None:
            # Inside a string literal: only look for its terminator
            if char == '\\':
                pos += 1  # Skip the escaped character
            elif char == quote:
                quote = None
        elif char in '\'"':
            quote = char
        elif char in '([{':
            depth += 1
        elif char in ')]}':
            if depth == 0:
                break  # Closing bracket of the call itself
            depth -= 1
        elif char == ',' and depth == 0:
            break  # End of the first argument
        pos += 1

    return expr[start:pos].strip()


def _is_simple_call(expr):
    """Return True if *expr* looks like a single call, e.g. ``f(x)`` or ``np.log(x)``"""

    if '(' not in expr or not expr.endswith(')'):
        return False

    # The callee must be a (possibly dotted) identifier
    callee = expr[:expr.index('(')]
    return callee.replace('.', '_').isidentifier()


def cols_for_formula(formula):
    """
    Return the columns corresponding to the Patsy formula

    Each factor on the right-hand side of *formula* is reduced to the column
    it refers to. A factor written as a call, e.g. ``C(x)``, ``np.log(x)`` or
    ``C(x, Treatment('b'))``, is replaced by the first argument of that call,
    and directly nested calls such as ``C(np.log(x))`` are unwrapped
    repeatedly.

    Returns a list of unique names in order of first appearance in the
    formula.

    NOTE: Arbitrary expressions such as ``I(x + y)`` are still returned
    verbatim (``x + y``) and are not split into individual columns.
    """

    # Parse the formula
    model_desc = patsy.ModelDesc.from_formula(formula)

    # Get the columns; dict keys give de-duplication with a stable order
    cols = {}
    for term in model_desc.rhs_termlist:
        for factor in term.factors:
            name = factor.name()

            if '(' in name:
                name = _first_call_argument(name)

                # Unwrap any directly nested calls, e.g. C(np.log(x))
                while _is_simple_call(name):
                    name = _first_call_argument(name)

            cols[name] = None

    return list(cols)
|
||||||
|
|
||||||
|
def formula_factor_ref_category(formula, df, factor):
    """
    Get the reference category for a term in a Patsy formula referring to a categorical factor

    formula: Patsy formula (right-hand side only, as it is passed to patsy.dmatrix)
    df: Data from which Patsy determines the levels of the factor
    factor: Name of the factor exactly as Patsy reports it, e.g. ``C(x)``

    The reference category is taken to be the level whose row in the factor's
    contrast matrix is all zeros, which is how treatment coding (including
    ``C(x, Treatment(reference=...))``) represents the reference level. If no
    contrast matrix with such a row is found (e.g. full-rank coding), the first
    category is returned.

    Raises ValueError if *factor* is not in the formula or is not categorical.
    """

    # Parse the formula
    design_info = patsy.dmatrix(formula, df).design_info

    # Get the corresponding factor_info
    for eval_factor, factor_info in design_info.factor_infos.items():
        if eval_factor.name() == factor:
            break
    else:
        raise ValueError('Factor {!r} not found in formula {!r}'.format(factor, formula))

    categories = factor_info.categories
    if categories is None:
        # Numerical factors have no categories
        raise ValueError('Factor {!r} is not categorical'.format(factor))

    # Look for a contrast matrix for this factor with an all-zero row
    for subterms in design_info.term_codings.values():
        for subterm in subterms:
            contrast_matrix = subterm.contrast_matrices.get(eval_factor)
            if contrast_matrix is None:
                continue

            for category, row in zip(categories, contrast_matrix.matrix):
                if all(value == 0 for value in row):
                    return category

    # No explicit reference level could be identified
    return categories[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user