120 lines
3.3 KiB
Python
120 lines
3.3 KiB
Python
# scipy-yli: Helpful SciPy utilities and recipes
|
|
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
import warnings
|
|
|
|
def check_nan(df, nan_policy):
    """Check df against nan_policy and return cleaned input

    Parameters:
        df: Input to check; presumably a pandas DataFrame or Series (it must
            support ``pd.isna(df).any(axis=None)`` and ``df.dropna()``).
        nan_policy: One of
            'raise' -- raise ValueError if any NaN is present, otherwise
                       return ``df`` unchanged;
            'warn'  -- drop rows containing NaN, emitting a warning if any
                       rows were dropped;
            'omit'  -- silently drop rows containing NaN.

    Returns:
        The input, with NaN rows removed for 'warn' and 'omit'.

    Raises:
        ValueError: if nan_policy is 'raise' and df contains NaN, or if
            nan_policy is not one of the recognised values.
    """

    if nan_policy == 'raise':
        if pd.isna(df).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
        # No NaN: hand back the input unchanged so every policy returns data
        # (without this the function fell through and returned None)
        return df
    elif nan_policy == 'warn':
        df_cleaned = df.dropna()
        if len(df_cleaned) < len(df):
            warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned)))
        return df_cleaned
    elif nan_policy == 'omit':
        return df.dropna()
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it
        raise ValueError('Invalid nan_policy, expected "raise", "warn" or "omit"')
|
|
|
|
def as_2groups(df, data, group):
    """Group the data by the given variable, ensuring only 2 groups

    Parameters:
        df: DataFrame containing both the data and the grouping column(s).
        data: Column label(s) to extract for each group (used as the column
            indexer of ``df.loc``).
        group: Grouping key passed to ``df.groupby``.

    Returns:
        ``(group1, data1, group2, data2)``: each group's key followed by the
        ``data`` values for the rows in that group, in the order that
        ``df.groupby(group).groups`` reports the groups.

    Raises:
        ValueError: if ``group`` does not take exactly 2 values.
    """

    # Map of group key -> index labels of the rows belonging to that group
    groups = list(df.groupby(group).groups.items())

    # Ensure only 2 groups to compare
    if len(groups) != 2:
        # ValueError subclasses Exception, so existing broad handlers still catch it
        raise ValueError('Got {} values for {}, expected 2'.format(len(groups), group))

    # NOTE(review): rows are selected by index *label*, so a DataFrame whose
    # index contains duplicate labels may yield repeated rows here; confirm
    # callers always pass a uniquely-indexed df.
    (group1, labels1), (group2, labels2) = groups
    data1 = df.loc[labels1, data]
    data2 = df.loc[labels2, data]

    return group1, data1, group2, data2
|
|
|
|
def do_fmt_p(p):
    """Return sign and formatted p value

    The result is a ``(sign, text)`` pair. ``sign`` is ``'<'`` when p falls
    below the 0.001 display floor and ``None`` otherwise. ``text`` ends in
    ``*`` exactly when p < 0.05. Precision depends on the band p falls in,
    so values close to the 0.05 cut-off keep three decimal places and show
    clearly which side of it they lie on. Values that compare false against
    every bound (e.g. NaN) use the final one-decimal format.
    """

    # (exclusive upper bound, sign, format template). The first band whose
    # bound exceeds p is used; the first template has no placeholder, so
    # .format() returns it verbatim.
    bands = (
        (0.001, '<', '0.001*'),
        (0.0095, None, '{:.3f}*'),
        (0.045, None, '{:.2f}*'),
        (0.05, None, '{:.3f}*'),   # 3dps to show significance
        (0.055, None, '{:.3f}'),   # 3dps to show non-significance
        (0.095, None, '{:.2f}'),
    )

    for bound, sign, template in bands:
        if p < bound:
            return sign, template.format(p)

    return None, '{:.1f}'.format(p)
|
|
|
|
def fmt_p_text(p, nospace=False):
    """Format p value for plaintext

    The spaced form always carries a relation, e.g. "< 0.001" or "= 0.07".
    With ``nospace`` true the separator is dropped, and exact values also
    lose their "=", e.g. "<0.001" or "0.07".
    """

    sign, value = do_fmt_p(p)

    if nospace:
        # Compact form: no separator, and no "=" for exact values
        return value if sign is None else sign + value

    # Spaced form: exact values get an explicit "=" relation
    prefix = '=' if sign is None else sign
    return prefix + ' ' + value
|
|
|
|
def fmt_p_html(p, nospace=False):
    """Format p value for HTML

    Produces the same text as ``fmt_p_text(p, nospace)``, but with '<'
    escaped as '&lt;' so the result can be embedded directly in HTML markup.
    """

    txt = fmt_p_text(p, nospace)
    # Escape '<' (as in "< 0.001") so a browser does not parse it as the
    # start of a tag; replacing '<' with itself was a no-op
    return txt.replace('<', '&lt;')
|
|
|
|
class Estimate:
    """A point estimate and surrounding confidence interval"""

    def __init__(self, point, ci_lower, ci_upper):
        self.point = point        # point estimate
        self.ci_lower = ci_lower  # lower confidence bound
        self.ci_upper = ci_upper  # upper confidence bound

    def __repr__(self):
        # Unambiguous form for consoles and logs; the default object repr
        # shows only the memory address and hides the values
        return 'Estimate({!r}, {!r}, {!r})'.format(self.point, self.ci_lower, self.ci_upper)

    def _repr_html_(self):
        """IPython/Jupyter rich-display hook; renders the summary text."""
        return self.summary()

    def summary(self):
        """Return 'point (lower–upper)', each formatted to 2 decimal places."""
        return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper)

    def __neg__(self):
        """Negate the estimate.

        The bounds swap places so that the negated upper bound becomes the
        new lower bound.
        """
        return Estimate(-self.point, -self.ci_upper, -self.ci_lower)

    def __abs__(self):
        """Return the negated estimate if the point is negative, else self."""
        if self.point < 0:
            return -self
        else:
            return self

    def exp(self):
        """Exponentiate the point estimate and both bounds.

        Presumably used to back-transform log-scale estimates; confirm with
        callers.
        """
        return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))
|