scipy-yli/yli/utils.py

244 lines
6.9 KiB
Python
Raw Normal View History

2022-10-13 12:53:18 +11:00
# scipy-yli: Helpful SciPy utilities and recipes
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
import pandas as pd
2022-10-15 00:54:46 +11:00
import patsy
2022-10-13 12:53:18 +11:00
import warnings
2022-10-15 00:54:46 +11:00
# ----------------------------
# Data cleaning and validation
2022-10-13 12:53:18 +11:00
def check_nan(df, nan_policy):
    """
    Check df against nan_policy and return cleaned input

    nan_policy is one of:
      "raise" - raise ValueError if df contains any NaN, otherwise return df unchanged
      "warn"  - drop rows containing NaN, emitting a warning if any rows were dropped
      "omit"  - silently drop rows containing NaN

    Raises ValueError for NaN under the "raise" policy, and Exception for an
    unrecognised nan_policy.
    """
    if nan_policy == 'raise':
        if pd.isna(df).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
        # No NaN present: the input is already clean, so hand it back
        # (previously this path fell through and returned None)
        return df
    elif nan_policy == 'warn':
        df_cleaned = df.dropna()
        n_dropped = len(df) - len(df_cleaned)
        if n_dropped > 0:
            warnings.warn('Omitting {} rows with NaN'.format(n_dropped))
        return df_cleaned
    elif nan_policy == 'omit':
        return df.dropna()
    else:
        raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
def as_2groups(df, data, group):
    """
    Split the column `data` of df into 2 groups according to the column `group`

    Returns: group1, data1, group2, data2
    where group1/group2 are the 2 distinct group labels and data1/data2 are the
    corresponding selections of df[data].

    Raises Exception if `group` does not take exactly 2 values.
    """
    # Mapping of group label -> index labels of the rows in that group
    index_by_group = df.groupby(group).groups

    n_groups = len(index_by_group)
    if n_groups != 2:
        raise Exception('Got {} values for {}, expected 2'.format(n_groups, group))

    # Exactly 2 entries, so they can be unpacked directly
    (group1, index1), (group2, index2) = index_by_group.items()

    data1 = df.loc[index1, data]
    data2 = df.loc[index2, data]

    return group1, data1, group2, data2
2022-10-15 00:54:46 +11:00
# ----------
# Formatting
2022-10-13 12:53:18 +11:00
def do_fmt_p(p):
    """
    Return sign and formatted p value

    Returns a 2-tuple (sign, text). sign is '<' when p is below 0.001 and None
    otherwise. text is p formatted to a number of decimal places depending on
    its magnitude, with a trailing '*' when p < 0.05.
    """
    if p < 0.001:
        return '<', '0.001*'

    # (exclusive upper bound, format template), checked in ascending order
    bands = (
        (0.0095, '{:.3f}*'),
        (0.045, '{:.2f}*'),
        (0.05, '{:.3f}*'),  # 3dps to show significance
        (0.055, '{:.3f}'),  # 3dps to show non-significance
        (0.095, '{:.2f}'),
    )
    for upper_bound, template in bands:
        if p < upper_bound:
            return None, template.format(p)

    return None, '{:.1f}'.format(p)
2022-10-14 14:48:26 +11:00
def fmt_p(p, *, html, nospace=False):
    """
    Format p value

    p: The p value to format
    html: If true, escape "<" as "&lt;" in the result
    nospace: If true, omit the space after the sign, and omit the "= " prefix
             when there is no sign

    Returns the formatted string, e.g. "< 0.001*" or "= 0.05".
    """
    sign, fmt = do_fmt_p(p)

    if sign is None:
        # e.g. "0.05" or "= 0.05"
        pfmt = fmt if nospace else '= ' + fmt
    else:
        # e.g. "<0.001" or "< 0.001"
        pfmt = sign + fmt if nospace else sign + ' ' + fmt

    return pfmt.replace('<', '&lt;') if html else pfmt
2022-10-13 12:53:18 +11:00
2022-10-15 00:54:46 +11:00
# ------------------------------
# General result-related classes
2022-10-13 12:53:18 +11:00
class Estimate:
    """A point estimate and surrounding confidence interval"""

    def __init__(self, point, ci_lower, ci_upper):
        # Point estimate
        self.point = point
        # Lower and upper bounds of the confidence interval
        self.ci_lower = ci_lower
        self.ci_upper = ci_upper

    def _repr_html_(self):
        """Return the summary string as the HTML representation"""
        return self.summary()

    def summary(self):
        """Return the estimate formatted as "point (lower–upper)" to 2 decimal places"""
        # \u2013 is an en dash separating the bounds; without a separator the
        # two bounds run together (e.g. "0.501.50") and cannot be read
        return '{:.2f} ({:.2f}\u2013{:.2f})'.format(self.point, self.ci_lower, self.ci_upper)

    def __neg__(self):
        """Return the negated estimate"""
        # Negation reverses ordering, so the bounds swap roles
        return Estimate(-self.point, -self.ci_upper, -self.ci_lower)

    def __abs__(self):
        """Return the estimate, negated if the point estimate is negative"""
        if self.point < 0:
            return -self
        else:
            return self

    def exp(self):
        """Return a new Estimate with np.exp applied to the point estimate and both bounds"""
        return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
2022-10-15 00:54:46 +11:00
# --------------------------
# Patsy formula manipulation
def cols_for_formula(formula, df):
    """
    Return the columns corresponding to the Patsy formula

    Only the right-hand side terms of the formula are inspected. The result is
    built from a set, so contains no duplicates and has no guaranteed order.
    """
    rhs_terms = patsy.ModelDesc.from_formula(formula).rhs_termlist

    # Names of every factor appearing in any right-hand side term
    factor_names = (factor.name() for term in rhs_terms for factor in term.factors)

    cols = set()
    for name in factor_names:
        if name.startswith('C('):
            # Contrasts expression: evaluate the factor and take the name of
            # the underlying data in place of the "C(...)" expression
            factor_info = formula_get_factor_info(formula, df, name)
            categorical_box = factor_info.factor.eval(factor_info.state, df)
            name = categorical_box.data.name

        cols.add(name)

    return list(cols)
def formula_get_factor_info(formula, df, factor):
    """
    Get the FactorInfo for a factor in a Patsy formula

    formula: Patsy formula, passed to patsy.dmatrix
    df: Data to evaluate the formula against
    factor: Name of the factor, compared against the name() of each factor in the design

    Returns the first matching value from design_info.factor_infos.
    Raises ValueError if no factor in the formula has the given name.

    Note the design matrix is rebuilt on every call.
    """
    # Parse the formula
    design_info = patsy.dmatrix(formula, df).design_info

    # Get the corresponding factor_info
    for candidate, factor_info in design_info.factor_infos.items():
        if candidate.name() == factor:
            return factor_info

    # Raise a descriptive error rather than leaking StopIteration from next(),
    # which could silently terminate a caller's iteration
    raise ValueError('Factor "{}" not found in formula "{}"'.format(factor, formula))
def formula_factor_ref_category(formula, df, factor):
    """
    Get the reference category for a term in a Patsy formula referring to a categorical factor

    A factor written without "C(...)" uses the first category. For "C(...)"
    factors, only Treatment contrasts are supported: the explicitly specified
    reference is returned if there is one, otherwise the first category.

    Raises Exception for an expression other than "C(...)", or for a contrast
    other than Treatment.
    """
    is_expression = '(' in factor

    if is_expression and not factor.startswith('C('):
        raise Exception('Attempted to get reference category for unknown expression type "{}"'.format(factor))

    factor_info = formula_get_factor_info(formula, df, factor)

    if not is_expression:
        # C(...) is not specified, so must be default
        return factor_info.categories[0]

    # Evaluate the factor to find out which contrast was requested
    contrast = factor_info.factor.eval(factor_info.state, df).contrast

    if contrast is None or contrast is patsy.Treatment:
        # Default Treatment contrast with default reference group: first category
        return factor_info.categories[0]

    if not isinstance(contrast, patsy.Treatment):
        raise Exception('Attempted to get reference category for unknown contrast type {}'.format(contrast.__class__.__name__))

    if contrast.reference is None:
        # Default reference group: first category
        return factor_info.categories[0]

    # Specified reference group
    return contrast.reference
def parse_patsy_term(formula, df, term):
    """
    Parse a Patsy term into its component parts

    Returns: factor, column, contrast
    e.g. "C(x, Treatment(y))[T.z]" -> "C(x, Treatment(y))", "x", "z"
    e.g. "x[T.z]" -> "x", "x", "z"
    e.g. "x" -> "x", "x", None

    formula and df are accepted for interface compatibility but are not used.

    Raises Exception if the term uses an expression other than "C(...)", a
    contrast other than Treatment ("[T.…]"), or "C(...)" with no contrast.
    """
    # Split off the treatment contrast suffix, if any
    marker = term.find('[T.')
    factor = term if marker < 0 else term[:marker]

    # Only the factor part decides whether this is an expression; a level
    # label may itself contain parentheses, e.g. "x[T.Other (specify)]"
    is_expression = '(' in factor
    if is_expression and not factor.startswith('C('):
        raise Exception('Attempted to parse term for unknown expression type "{}"'.format(term))

    if marker < 0:
        if '[' in term:
            raise Exception('Attempted to parse term for unknown contrast type "{}"'.format(term))
        if is_expression:
            # Not a treatment contrast (I think this is impossible?)
            raise Exception('Attempted to parse unsupported contrast-like term with no contrasts')
        # Nothing special
        return term, term, None

    # Treatment contrast term: the level runs up to the first "]" after "[T."
    contrast = term[marker + 3:term.index(']', marker)]

    if not is_expression:
        return factor, factor, contrast

    # "C(column, ...)": the column is the first argument
    factor_inner = factor[factor.index('(') + 1:factor.rindex(')')]
    column = factor_inner.partition(',')[0]

    return factor, column, contrast