scipy-yli/yli/utils.py

478 lines
13 KiB
Python

# scipy-yli: Helpful SciPy utilities and recipes
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
import pandas as pd
import patsy
import enum
import warnings
from .config import config
# ----------------------------
# Data cleaning and validation
def check_nan(df, nan_policy, *, cols=None):
    """
    Apply *nan_policy* to *df* and return the data to analyse

    * ``"raise"``: raise if the checked columns contain any NaN, otherwise return *df* itself
    * ``"warn"``: drop rows with NaN in the checked columns, warning if any row was dropped
    * ``"omit"``: drop rows with NaN in the checked columns without warning

    :param df: Data to check for NaNs
    :type df: DataFrame
    :param nan_policy: Policy to apply when encountering NaN values (*warn*, *raise*, *omit*)
    :type nan_policy: str
    :param cols: Columns to check for NaN, or *None* for all columns
    :type cols: List[str]

    :raises ValueError: If *nan_policy* is *raise* and a NaN is found
    :raises Exception: If *nan_policy* is not one of the recognised values

    :return: Data with NaNs removed, which may or may not be copied
    :rtype: DataFrame
    """
    if nan_policy == 'raise':
        # Only the requested columns are inspected, but the full frame is returned
        inspected = df[cols] if cols is not None else df
        if pd.isna(inspected).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
        return df

    if nan_policy == 'omit':
        return df.dropna(subset=cols)

    if nan_policy == 'warn':
        without_nan = df.dropna(subset=cols)
        n_dropped = len(df) - len(without_nan)
        if n_dropped > 0:
            warnings.warn('Omitting {} rows with NaN'.format(n_dropped))
        return without_nan

    raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
def convert_pandas_nullable(df):
    """
    Replace pandas nullable dtypes (e.g. *Int64*) with their non-nullable numpy equivalents

    *Int64* and *Float64* columns are cast to *int64* and *float64*, and *boolean* columns to *bool*.
    Behaviour on encountering *NA* values is undefined, so the data should be passed through :func:`check_nan` first.

    :param df: Data to check for pandas nullable dtypes
    :type df: DataFrame

    :return: Data with pandas nullable dtypes converted, which may or may not be copied
    :rtype: DataFrame
    """
    # First work out which columns need casting, and to what
    target_dtypes = {}
    for name in df.columns:
        dtype = df[name].dtype
        if dtype in ('Int64', 'Float64'):
            # The numpy dtype name is the lower-cased pandas name
            target_dtypes[name] = str(dtype).lower()
        elif dtype == 'boolean':
            target_dtypes[name] = 'bool'

    # Avoid a copy if there is nothing to convert
    if not target_dtypes:
        return df

    converted = df.copy()
    for name, target in target_dtypes.items():
        converted[name] = df[name].astype(target)
    return converted
def as_2groups(df, data, group):
    """
    Split *df* by a grouping column, asserting that it takes exactly 2 values

    :param df: Data to group
    :type df: DataFrame
    :param data: Column label(s) to select from *df* for each group (passed to ``df.loc``)
    :param group: Column to group by
    :type group: str

    :raises Exception: If the grouping column does not have exactly 2 distinct values

    :return: (*group1*, *data1*, *group2*, *data2*)

        * **group1**, **group2** – The 2 values of the grouping variable
        * **data1**, **data2** – The corresponding *data* selections from *df*
    """
    # Mapping of group value -> index labels of the rows in that group
    rows_by_value = df.groupby(group).groups

    # Ensure only 2 groups to compare
    if len(rows_by_value) != 2:
        raise Exception('Got {} values for {}, expected 2'.format(len(rows_by_value), group))

    # FIXME: Sort order
    (group1, rows1), (group2, rows2) = rows_by_value.items()
    return group1, df.loc[rows1, data], group2, df.loc[rows2, data]
def as_numeric(data):
    """
    Convert the data to a numeric type, factorising if required

    *float64* data is returned as is. Categorical data whose categories have *object* dtype is
    factorised with sorted categories. Anything else is cast to *float64*.

    :param data: Data to convert
    :type data: Series

    :return: (*values*, *uniques*) as for *pandas.factorize*; *uniques* is *None* when no factorising was required
    """
    # Already numeric: nothing to do
    if data.dtype == 'float64':
        return data, None

    needs_factorising = data.dtype == 'category' and data.cat.categories.dtype == 'object'
    if not needs_factorising:
        return data.astype('float64'), None

    return data.factorize(sort=True)
def as_ordinal(data):
    """
    Convert the data to an ordered category dtype

    Data that is already an ordered categorical is returned as is. An unordered categorical keeps
    its categories and is marked as ordered. Anything else is cast to an ordered categorical.

    :param data: Data to convert
    :type data: Series

    :rtype: Series
    """
    if data.dtype == 'category':
        return data if data.cat.ordered else data.cat.as_ordered()

    # Not yet categorical
    ordered_dtype = pd.CategoricalDtype(ordered=True)
    return data.astype(ordered_dtype)
# ----------
# Formatting
def do_fmt_p(p):
    """
    Return sign and formatted p value

    :param p: *p* value to format
    :type p: float

    :return: (*sign*, *formatted*), where *sign* is ``'<'``, ``'>'`` or an empty string, and
        *formatted* is the number to display after it
    """
    min_dps = config.pvalue_min_dps
    max_dps = config.pvalue_max_dps

    # Smallest value representable at pvalue_max_dps
    finest_step = 10**-max_dps
    if p < finest_step:
        # Smaller than min value
        return '<', '{:.1g}'.format(finest_step)

    # Smallest value representable at pvalue_min_dps
    coarse_step = 10**-min_dps
    largest_shown = 1 - coarse_step
    if p > largest_shown:
        # Larger than max value
        return '>', '{0:.{dps}f}'.format(largest_shown, dps=min_dps)

    alpha = config.alpha
    if round(p, min_dps) == alpha:
        # Rounding to pvalue_min_dps makes significance ambiguous, so show pvalue_max_dps instead
        if round(p, max_dps) == alpha:
            # Still ambiguous to pvalue_max_dps: nudge away from alpha in the direction
            # that preserves significance (down) or nonsignificance (up)
            p = alpha - finest_step if p < alpha else alpha + finest_step
        return '', '{0:.{dps}f}'.format(p, dps=max_dps)

    if p < coarse_step:
        # Insufficient resolution at pvalue_min_dps
        # We know from earlier comparison that 1 s.f. fits within pvalue_max_dps
        return '', '{:.1g}'.format(p)

    # OK to round to pvalue_min_dps
    return '', '{0:.{dps}f}'.format(p, dps=min_dps)
class PValueStyle(enum.Flag):
    """An *enum.Flag* representing how to render a *p* value; flags may be combined with ``|``"""

    #: No flags set: render only the value
    VALUE_ONLY = 0
    #: Prefix the value with a relational operator (``=`` when there is no ``<`` or ``>``)
    RELATION = 1 << 0
    #: Reserve space for the sign so that values line up in a table column
    TABULAR = 1 << 1
    #: Produce HTML output (angle brackets in the sign are escaped)
    HTML = 1 << 2
def fmt_p(p, style):
    """
    Format *p* value for display

    An asterisk is appended when *p* is below ``config.alpha``. When both *RELATION* and
    *TABULAR* are set in *style*, *RELATION* takes precedence.

    :param p: *p* value to display
    :type p: float
    :param style: Style to format the *p* value
    :type style: :class:`PValueStyle`

    :return: Formatted *p* value
    :rtype: str
    """
    sign, fmt = do_fmt_p(p)

    # Strip leading zero if required
    if not config.pvalue_leading_zero:
        fmt = fmt.lstrip('0')

    # Mark significant values
    asterisk = '*' if p < config.alpha else ''

    as_html = PValueStyle.HTML in style
    if as_html:
        # Escape angle quotes
        sign = sign.replace('<', '&lt;').replace('>', '&gt;')

    if PValueStyle.RELATION in style:
        # Add relational operator, defaulting to equality
        return '{} {}{}'.format(sign or '=', fmt, asterisk)

    if PValueStyle.TABULAR not in style:
        # Only value
        return '{}{}{}'.format(sign, fmt, asterisk)

    if as_html:
        # Always left-aligned, so reserve space for sign if required to align decimal points
        if not sign:
            sign = '<span style="visibility:hidden">=</span>'
        return '{}{}{}'.format(sign, fmt, asterisk)

    # Plain-text table: right-aligned, so add spaces to simulate left alignment
    if not sign:
        sign = ' '

    # Widest possible output: decimal places, +1 for decimal point, +1 for sign,
    # +1 for asterisk, and +1 for the leading zero if shown
    widest = config.pvalue_max_dps + 3 + (1 if config.pvalue_leading_zero else 0)
    return (sign + fmt + asterisk).ljust(widest)
# ------------------------------
# General result-related classes
class Interval:
    """An interval (e.g. confidence interval)"""

    def __init__(self, lower, upper):
        #: Lower limit (*float*)
        self.lower = lower
        #: Upper limit (*float*)
        self.upper = upper

    def __repr__(self):
        # Optionally show the human-readable summary in place of the default object repr
        if config.repr_is_summary:
            return self.summary()
        return super().__repr__()

    def _repr_html_(self):
        # Rich display hook (e.g. Jupyter): same text as the plain summary
        return self.summary()

    def summary(self):
        """
        Return a stringified summary of the interval

        The limits are shown to 2 decimal places, separated by an en dash, e.g. ``1.00–2.50``.

        :rtype: str
        """
        # The en dash separator is required: without it the two limits run together (e.g. "1.002.50")
        return '{:.2f}–{:.2f}'.format(self.lower, self.upper)
class Estimate:
    """A point estimate and surrounding confidence interval"""

    def __init__(self, point, ci_lower, ci_upper):
        #: Point estimate (*float*)
        self.point = point
        #: Lower confidence limit (*float*)
        self.ci_lower = ci_lower
        #: Upper confidence limit (*float*)
        self.ci_upper = ci_upper

    def __repr__(self):
        # Optionally show the human-readable summary in place of the default object repr
        if config.repr_is_summary:
            return self.summary()
        return super().__repr__()

    def _repr_html_(self):
        # Rich display hook (e.g. Jupyter): same text as the plain summary
        return self.summary()

    def summary(self):
        """
        Return a stringified summary of the estimate and confidence interval

        Values are shown to 2 decimal places with the confidence limits separated by an
        en dash, e.g. ``1.00 (0.50–1.50)``.

        :rtype: str
        """
        # The en dash separator is required: without it the two limits run together
        return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper)

    def __neg__(self):
        # Negation reverses ordering, so the negated upper limit becomes the new lower limit
        return Estimate(-self.point, -self.ci_upper, -self.ci_lower)

    def __abs__(self):
        # Flip the whole estimate when the point is negative; the limits follow the point
        # rather than being made non-negative individually
        if self.point < 0:
            return -self
        return self

    def exp(self):
        """
        Return a new estimate with the point and both confidence limits exponentiated

        :rtype: :class:`Estimate`
        """
        return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
# --------------------------
# Patsy formula manipulation
def cols_for_formula(formula, df):
    """
    Return the columns corresponding to the Patsy formula

    Factors written as ``C(...)`` are evaluated against *df* to recover the name of the
    underlying column; any other factor contributes its name as written.

    :param formula: Patsy formula to parse
    :type formula: str
    :param df: Data to apply the formula on
    :type df: DataFrame

    :return: Columns in (the right-hand side of) the formula, without duplicates and in no particular order
    :rtype: List[str]
    """
    # Parse the formula
    model_desc = patsy.ModelDesc.from_formula(formula)

    # Names of every factor of every right-hand side term
    factor_names = (
        factor.name()
        for term in model_desc.rhs_termlist
        for factor in term.factors
    )

    cols = set()
    for name in factor_names:
        if name.startswith('C('):
            # Contrasts expression: evaluate the factor and read the column name off the result
            factor_info = formula_get_factor_info(formula, df, name)
            name = factor_info.factor.eval(factor_info.state, df).data.name
        cols.add(name)

    return list(cols)
def formula_get_factor_info(formula, df, factor):
    """
    Get the FactorInfo for a factor in a Patsy formula

    :param formula: Patsy formula to parse
    :type formula: str
    :param df: Data to apply the formula on
    :type df: DataFrame
    :param factor: Name of the factor, exactly as Patsy reports it (e.g. ``C(Country)``)
    :type factor: str

    :raises ValueError: If the formula contains no factor with the given name

    :return: The matching entry from the design matrix's ``factor_infos``
    """
    # Build the design matrix to obtain the factor metadata
    design_info = patsy.dmatrix(formula, df).design_info

    # Get the corresponding factor_info
    for candidate, factor_info in design_info.factor_infos.items():
        if candidate.name() == factor:
            return factor_info

    # Fail with a descriptive error rather than leaking a bare StopIteration from next()
    raise ValueError('Factor "{}" not found in formula "{}"'.format(factor, formula))
def formula_factor_ref_category(formula, df, factor):
    """
    Get the reference category for a term in a Patsy formula referring to a categorical factor

    :param formula: Patsy formula to parse
    :type formula: str
    :param df: Data to apply the formula on
    :type df: DataFrame
    :param factor: Factor to determine reference category for (e.g. ``Country``, ``C(Country)``, ``C(Country, Treatment)``, ``C(Country, Treatment("Australia"))``)

    :raises Exception: If *factor* is an expression other than ``C(...)``, or uses a contrast other than Treatment

    :return: Reference category for the specified factor
    """
    is_expression = '(' in factor
    if is_expression and not factor.startswith('C('):
        raise Exception('Attempted to get reference category for unknown expression type "{}"'.format(factor))

    # Get the factor_info
    factor_info = formula_get_factor_info(formula, df, factor)

    if not is_expression:
        # C(...) is not specified, so must be default
        return factor_info.categories[0]

    # Evaluate the factor to inspect the requested contrast
    contrast = factor_info.factor.eval(factor_info.state, df).contrast

    # The first category is the reference when no contrast is given, when the Treatment class
    # itself is given, or when a Treatment instance does not name a reference
    uses_first_category = (
        contrast is None
        or contrast is patsy.Treatment
        or (isinstance(contrast, patsy.Treatment) and contrast.reference is None)
    )
    if uses_first_category:
        return factor_info.categories[0]

    if isinstance(contrast, patsy.Treatment):
        # Specified reference group
        return contrast.reference

    raise Exception('Attempted to get reference category for unknown contrast type {}'.format(contrast.__class__.__name__))
def parse_patsy_term(formula, df, term):
    """
    Parse a Patsy term into its component parts

    **Example:** The term ``"C(x, Treatment(y))[T.z]"`` parses to ``("C(x, Treatment(y))", "x", "z")``.

    Parsing works on the text of *term* alone; *formula* and *df* are accepted but not consulted.

    :param formula: Patsy formula the term belongs to
    :param df: Data the formula is applied on
    :param term: Name of the term to parse
    :type term: str

    :raises Exception: If the term is an expression other than ``C(...)``, uses a contrast other than Treatment (``[T.…]``), or is a ``C(...)`` expression with no contrast

    :return: (*factor*, *column*, *contrast*)

        * **factor** (*str*) – Name of the factor, as specified in the Patsy formula
        * **column** (*str*) – Name of the DataFrame column corresponding to the factor
        * **contrast** (*str*) – Name of the contrast for the factor, or *None* if not applicable
    """
    is_expression = '(' in term
    if is_expression and not term.startswith('C('):
        raise Exception('Attempted to parse term for unknown expression type "{}"'.format(term))

    if '[' not in term:
        if is_expression:
            # Not a treatment contrast (I think this is impossible?)
            raise Exception('Attempted to parse unsupported contrast-like term with no contrasts')
        # Nothing special
        return term, term, None

    marker = term.find('[T.')
    if marker < 0:
        raise Exception('Attempted to parse term for unknown contrast type "{}"'.format(term))

    # Treatment contrast term: split "factor[T.contrast]"
    factor = term[:marker]
    contrast = term[marker + 3:term.index(']')]

    if not is_expression:
        # A bare column name is its own factor
        return factor, factor, contrast

    # The column is the first argument of C(...)
    arguments = factor[factor.index('(') + 1:factor.rindex(')')]
    column = arguments.partition(',')[0]
    return factor, column, contrast