2022-10-13 12:53:18 +11:00
|
|
|
# scipy-yli: Helpful SciPy utilities and recipes
|
|
|
|
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
|
|
|
|
#
|
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
|
|
# (at your option) any later version.
|
|
|
|
#
|
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
# GNU Affero General Public License for more details.
|
|
|
|
#
|
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
2022-10-15 00:54:46 +11:00
|
|
|
import patsy
|
2022-10-13 12:53:18 +11:00
|
|
|
|
|
|
|
import warnings
|
|
|
|
|
2022-10-15 23:11:22 +11:00
|
|
|
from .config import config
|
|
|
|
|
2022-10-15 00:54:46 +11:00
|
|
|
# ----------------------------
|
|
|
|
# Data cleaning and validation
|
|
|
|
|
2022-10-13 12:53:18 +11:00
|
|
|
def check_nan(df, nan_policy):
    """
    Check df against nan_policy and return cleaned input

    nan_policy: "raise" to raise ValueError if df contains any NaN;
        "warn" to drop rows containing NaN, emitting a warning if any rows were dropped;
        "omit" to silently drop rows containing NaN

    Returns df with NaN-containing rows dropped ("warn"/"omit"), or df unchanged ("raise").
    Raises ValueError on NaN under "raise", or on an unrecognised nan_policy.
    """

    if nan_policy == 'raise':
        if pd.isna(df).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
        # Nothing to clean: return the input so callers can always use the return value
        return df
    elif nan_policy == 'warn':
        df_cleaned = df.dropna()
        if len(df_cleaned) < len(df):
            warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned)))
        return df_cleaned
    elif nan_policy == 'omit':
        return df.dropna()
    else:
        # ValueError (a subclass of Exception) signals a bad argument value
        raise ValueError('Invalid nan_policy, expected "raise", "warn" or "omit"')
|
|
|
|
|
|
|
|
def as_2groups(df, data, group):
    """
    Group the data by the given variable, ensuring only 2 groups

    Returns (group1, data1, group2, data2), where each data is the `data` column(s) of df
    restricted to the rows belonging to that group. Raises Exception if `group` does not
    take exactly 2 values.
    """

    # Mapping of group label -> row index labels
    grouped = df.groupby(group).groups

    # Exactly 2 groups are required for a comparison
    if len(grouped) != 2:
        raise Exception('Got {} values for {}, expected 2'.format(len(grouped), group))

    (group1, rows1), (group2, rows2) = grouped.items()

    return group1, df.loc[rows1, data], group2, df.loc[rows2, data]
|
|
|
|
|
2022-10-15 00:54:46 +11:00
|
|
|
# ----------
|
|
|
|
# Formatting
|
|
|
|
|
2022-10-13 12:53:18 +11:00
|
|
|
def do_fmt_p(p):
    """
    Return sign and formatted p value

    Returns a tuple (sign, formatted), where sign is '', '<' or '>'. Formatting is controlled by
    config.pvalue_min_dps, config.pvalue_max_dps and config.alpha. The value is always rendered in
    fixed-point notation, never scientific notation.
    """

    min_dps = config.pvalue_min_dps
    max_dps = config.pvalue_max_dps
    alpha = config.alpha

    if p < 10**-max_dps:
        # Smaller than min value
        # Fixed-point formatting: '{:.1g}' would give e.g. '1e-05' when max_dps >= 5
        return '<', '{0:.{dps}f}'.format(10**-max_dps, dps=max_dps)

    if p > 1 - 10**-min_dps:
        # Larger than max value
        return '>', '{0:.{dps}f}'.format(1 - 10**-min_dps, dps=min_dps)

    if round(p, min_dps) == alpha:
        # Rounding to pvalue_min_dps makes significance ambiguous

        if round(p, max_dps) == alpha:
            # Still ambiguous to pvalue_max_dps
            if p < alpha:
                # Significant: round down
                p = alpha - 10**-max_dps
            else:
                # Nonsignificant: round up
                p = alpha + 10**-max_dps

        return '', '{0:.{dps}f}'.format(p, dps=max_dps)

    if p < 10**-min_dps:
        # Insufficient resolution at pvalue_min_dps, so use 1 significant figure
        # We know from earlier comparison that 1 s.f. fits within pvalue_max_dps
        # The exponent of p rounded to 1 s.f. (e.g. 0.0099 -> '1e-02') gives the number of decimal places
        # needed, so fixed-point formatting at that precision yields the same digits as '{:.1g}' without
        # ever switching to scientific notation
        exponent = int('{:.0e}'.format(p).split('e')[1])
        return '', '{0:.{dps}f}'.format(p, dps=max(0, -exponent))

    # OK to round to pvalue_min_dps
    return '', '{0:.{dps}f}'.format(p, dps=min_dps)
|
2022-10-13 12:53:18 +11:00
|
|
|
|
2022-10-18 19:23:57 +11:00
|
|
|
def fmt_p(p, *, html, only_value=False, tabular=False):
    """
    Format p value

    html: If true, output HTML (a '<' or '>' sign is escaped as an HTML entity); otherwise output plain text
    only_value: If true, output only the sign (if any), value and significance asterisk, e.g. "<0.001*"
    tabular: If true, output in ‘tabular’ format of p values where decimal points align

    Values with p < config.alpha are suffixed with an asterisk.
    """

    # FIXME: Make only_value and tabular enums

    sign, fmt = do_fmt_p(p)

    # Strip leading zero if required
    if not config.pvalue_leading_zero:
        fmt = fmt.lstrip('0')

    # Check if significant
    asterisk = '*' if p < config.alpha else ''

    if html:
        # Escape angle brackets so that e.g. "<0.001" is not parsed as the start of a tag
        sign = sign.replace('<', '&lt;').replace('>', '&gt;')

    if only_value:
        return '{}{}{}'.format(sign, fmt, asterisk)

    if not tabular:
        # Non-tabular so force a sign
        if not sign:
            sign = '='
        return '{} {}{}'.format(sign, fmt, asterisk)

    if html:
        # Always left-aligned, so reserve space for sign if required to align decimal points
        if not sign:
            sign = '<span style="visibility:hidden">=</span>'
        return '{}{}{}'.format(sign, fmt, asterisk)

    # Plain text: right-aligned, so add spaces to simulate left alignment
    if not sign:
        sign = ' '

    # Digits after the decimal point
    # +1 for decimal point
    # +1 for sign
    # +1 for asterisk
    pvalue_max_len = config.pvalue_max_dps + 3
    if config.pvalue_leading_zero:
        # +1 for the leading zero
        pvalue_max_len += 1

    # Now add spaces
    rpadding = ' ' * (pvalue_max_len - len(sign + fmt + asterisk))

    return '{}{}{}{}'.format(sign, fmt, asterisk, rpadding)
|
2022-10-13 12:53:18 +11:00
|
|
|
|
2022-10-15 00:54:46 +11:00
|
|
|
# ------------------------------
|
|
|
|
# General result-related classes
|
|
|
|
|
2022-10-18 17:57:19 +11:00
|
|
|
class ConfidenceInterval:
    """A confidence interval, described by its lower and upper limits"""

    def __init__(self, lower, upper):
        #: Lower confidence limit (*float*)
        self.lower = lower
        #: Upper confidence limit (*float*)
        self.upper = upper

    def __repr__(self):
        # Show the summary instead of the default object repr when configured to do so
        return self.summary() if config.repr_is_summary else super().__repr__()

    def _repr_html_(self):
        # Rich HTML representation hook (IPython display protocol)
        return self.summary()

    def summary(self):
        """
        Return a stringified summary of the confidence interval

        Both limits are shown to 2 decimal places, separated by an en dash.

        :rtype: str
        """

        return '{lo:.2f}–{hi:.2f}'.format(lo=self.lower, hi=self.upper)
|
2022-10-17 21:41:19 +11:00
|
|
|
|
2022-10-13 12:53:18 +11:00
|
|
|
class Estimate:
    """A point estimate together with its surrounding confidence interval"""

    def __init__(self, point, ci_lower, ci_upper):
        #: Point estimate (*float*)
        self.point = point
        #: Lower confidence limit (*float*)
        self.ci_lower = ci_lower
        #: Upper confidence limit (*float*)
        self.ci_upper = ci_upper

    def __repr__(self):
        # Fall back to the default object repr unless configured to show the summary
        if not config.repr_is_summary:
            return super().__repr__()
        return self.summary()

    def _repr_html_(self):
        # Rich HTML representation hook (IPython display protocol)
        return self.summary()

    def summary(self):
        """
        Return a stringified summary of the estimate and confidence interval

        Formatted as "point (lower–upper)", each to 2 decimal places.

        :rtype: str
        """

        point_str = '{:.2f}'.format(self.point)
        interval_str = '{:.2f}–{:.2f}'.format(self.ci_lower, self.ci_upper)
        return '{} ({})'.format(point_str, interval_str)

    def __neg__(self):
        # Negation reverses the ordering of the limits, so they swap places
        return Estimate(point=-self.point, ci_lower=-self.ci_upper, ci_upper=-self.ci_lower)

    def __abs__(self):
        # Flip the whole estimate (including its interval) when the point is negative
        return -self if self.point < 0 else self

    def exp(self):
        """Return a new Estimate with np.exp applied to the point and both limits"""
        return Estimate(*map(np.exp, (self.point, self.ci_lower, self.ci_upper)))
|
2022-10-15 00:54:46 +11:00
|
|
|
|
|
|
|
# --------------------------
|
|
|
|
# Patsy formula manipulation
|
|
|
|
|
2022-10-15 01:09:40 +11:00
|
|
|
def cols_for_formula(formula, df):
    """
    Return the columns corresponding to the Patsy formula

    Only the right-hand-side terms of the formula are inspected. For C(...) factors, the underlying
    column name is obtained by evaluating the factor against df. Each column appears once, in the
    order first encountered in the parsed term list.
    """

    # Parse the formula
    model_desc = patsy.ModelDesc.from_formula(formula)

    # Get the columns
    # An insertion-ordered dict de-duplicates like a set but gives a deterministic order
    # (iteration order of a set of strings varies between runs due to hash randomisation)
    cols = {}
    for term in model_desc.rhs_termlist:
        for factor in term.factors:
            name = factor.name()

            if name.startswith('C('):
                # Contrasts expression
                # Get the corresponding factor_info
                factor_info = formula_get_factor_info(formula, df, name)

                # Evaluate the factor
                categorical_box = factor_info.factor.eval(factor_info.state, df)

                # Get the column name
                name = categorical_box.data.name

            cols[name] = None

    return list(cols)
|
|
|
|
|
2022-10-15 01:09:40 +11:00
|
|
|
def formula_get_factor_info(formula, df, factor):
    """
    Get the FactorInfo for a factor in a Patsy formula

    formula: formula passed to patsy.dmatrix together with df
    factor: name of the factor to look up, compared against each patsy factor's name()

    Raises ValueError if no factor with that name appears in the design matrix.
    """

    # Parse the formula
    design_info = patsy.dmatrix(formula, df).design_info

    # Get the corresponding factor_info
    for candidate, factor_info in design_info.factor_infos.items():
        if candidate.name() == factor:
            return factor_info

    # Raise a descriptive error rather than letting a bare StopIteration escape
    # (which would become a RuntimeError if this were called inside a generator)
    raise ValueError('Factor "{}" not found in formula "{}"'.format(factor, formula))
|
|
|
|
|
|
|
|
def formula_factor_ref_category(formula, df, factor):
    """
    Get the reference category for a term in a Patsy formula referring to a categorical factor

    factor: either a bare column name (e.g. "x") or a C(...) expression (e.g. "C(x, Treatment('b'))")

    Only Treatment contrasts are supported; raises Exception for other expression types or contrast types.
    """

    if '(' in factor and not factor.startswith('C('):
        raise Exception('Attempted to get reference category for unknown expression type "{}"'.format(factor))

    # Get the factor_info
    factor_info = formula_get_factor_info(formula, df, factor)
    categories = factor_info.categories

    if '(' not in factor:
        # C(...) is not specified, so must be default
        return categories[0]

    # Evaluate the factor
    categorical_box = factor_info.factor.eval(factor_info.state, df)
    contrast = categorical_box.contrast

    if contrast is None or contrast is patsy.Treatment:
        # Default Treatment contrast with default reference group: first category
        return categories[0]

    if isinstance(contrast, patsy.Treatment):
        reference = contrast.reference

        if reference is None:
            # Default reference group: first category
            return categories[0]

        if reference in categories:
            # Specified reference group, given as a category label
            return reference

        if isinstance(reference, (int, np.integer)):
            # Patsy also accepts the reference as an integer index into the categories
            # (negative indices allowed); labels are matched first, as above
            return categories[reference]

        # Specified reference group
        return reference

    raise Exception('Attempted to get reference category for unknown contrast type {}'.format(contrast.__class__.__name__))
|
|
|
|
|
|
|
|
def _split_treatment_contrast(term):
    """Split "<factor>[T.<level>]" into (factor, level); raise Exception for non-Treatment contrasts"""

    if '[T.' not in term:
        raise Exception('Attempted to parse term for unknown contrast type "{}"'.format(term))

    start = term.index('[T.')
    # Use the LAST ']' as the end of the level: the factor may itself contain brackets,
    # e.g. "C(x, levels=['a', 'b'])[T.b]", and so may the level name
    end = term.rindex(']')

    return term[:start], term[start + 3:end]


def parse_patsy_term(formula, df, term):
    """
    Parse a Patsy term into its component parts

    Returns: factor, column, contrast
    e.g. "C(x, Treatment(y))[T.z]" -> "C(x, Treatment(y))", "x", "z"
    e.g. "x[T.z]" -> "x", "x", "z"
    e.g. "x" -> "x", "x", None

    formula and df are not used by this function.
    Raises Exception for expressions other than C(...) and for contrasts other than Treatment.
    """

    if '(' not in term:
        if '[' in term:
            # Treatment contrast term
            factor, contrast = _split_treatment_contrast(term)
            return factor, factor, contrast

        # Nothing special
        return term, term, None

    # Term contains '('

    if not term.startswith('C('):
        raise Exception('Attempted to parse term for unknown expression type "{}"'.format(term))

    if '[' not in term:
        # Not a treatment contrast (I think this is impossible?)
        raise Exception('Attempted to parse unsupported contrast-like term with no contrasts')

    # Treatment contrast term
    factor, contrast = _split_treatment_contrast(term)

    # The column is the first argument of C(...)
    factor_inner = factor[factor.index('(') + 1:factor.rindex(')')]
    column = factor_inner.split(',', 1)[0].strip()

    return factor, column, contrast
|