diff --git a/yli/__init__.py b/yli/__init__.py index 339caef..e7abd11 100644 --- a/yli/__init__.py +++ b/yli/__init__.py @@ -14,7 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from .distributions import * +from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist +from .sig_tests import ttest_ind def reload_me(): import importlib diff --git a/yli/sig_tests.py b/yli/sig_tests.py new file mode 100644 index 0000000..319eb68 --- /dev/null +++ b/yli/sig_tests.py @@ -0,0 +1,140 @@ +# scipy-yli: Helpful SciPy utilities and recipes +# Copyright © 2022 Lee Yingtong Li (RunasSudo) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>.

import functools
import warnings

import pandas as pd
from scipy import stats
import statsmodels.api as sm


def check_nan(df, nan_policy):
    """Check df against nan_policy and return the cleaned input.

    Parameters:
        df: pandas DataFrame to check.
        nan_policy: One of
            'raise' - raise ValueError if df contains any NaN, otherwise
                      return df unchanged;
            'warn'  - drop rows containing NaN and emit a warning stating
                      how many rows were dropped (no warning if none were);
            'omit'  - silently drop rows containing NaN.

    Returns:
        The DataFrame to use for analysis (df itself, or df with NaN rows
        dropped).

    Raises:
        ValueError: If nan_policy is 'raise' and df contains NaN, or if
            nan_policy is not one of the recognised values.
    """

    if nan_policy == 'raise':
        if pd.isna(df).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
        # Clean input: hand it back so callers can always use the return value
        return df

    if nan_policy == 'warn':
        df_cleaned = df.dropna()
        n_dropped = len(df) - len(df_cleaned)
        if n_dropped > 0:
            warnings.warn('Omitting {} rows with NaN'.format(n_dropped))
        return df_cleaned

    if nan_policy == 'omit':
        return df.dropna()

    raise ValueError('Invalid nan_policy, expected "raise", "warn" or "omit"')


def do_fmt_p(p):
    """Return sign and formatted p value.

    The number of decimal places depends on the magnitude of p, and a
    trailing '*' is appended whenever p < 0.05. Values just either side of
    0.05 are shown to 3 decimal places so that rounding cannot hide which
    side of the threshold they fall on.

    Returns:
        A tuple (sign, text). sign is '<' when p is below 0.001 (text is
        then the bound '0.001*'), otherwise None.
    """

    if p < 0.001:
        return '<', '0.001*'
    elif p < 0.0095:
        return None, '{:.3f}*'.format(p)
    elif p < 0.045:
        return None, '{:.2f}*'.format(p)
    elif p < 0.05:
        return None, '{:.3f}*'.format(p)  # 3dps to show significance
    elif p < 0.055:
        return None, '{:.3f}'.format(p)  # 3dps to show non-significance
    elif p < 0.095:
        return None, '{:.2f}'.format(p)
    else:
        return None, '{:.1f}'.format(p)


def fmt_p_text(p, nospace=False):
    """Format p value for plaintext.

    Parameters:
        p: The p value.
        nospace: If False (default), the result includes a relational
            operator followed by a space, e.g. "< 0.001*" or "= 0.05".
            If True, the compact form is returned: "<0.001*" for bounded
            values and just the number (no "=") otherwise, e.g. "0.05".
    """

    sign, fmt = do_fmt_p(p)
    if sign is not None:
        if nospace:
            return sign + fmt  # e.g. "<0.001"
        else:
            return sign + ' ' + fmt  # e.g. "< 0.001"
    else:
        if nospace:
            return fmt  # e.g. "0.05"
        else:
            return '= ' + fmt  # e.g. "= 0.05"


def fmt_p_html(p, nospace=False):
    """Format p value for HTML.

    Same as fmt_p_text, but with the '<' sign escaped as an HTML entity so
    the result can be embedded directly in markup.
    """

    txt = fmt_p_text(p, nospace)
    return txt.replace('<', '&lt;')


class Estimate:
    """A point estimate and surrounding confidence interval.

    Attributes:
        point: The point estimate.
        ci_lower: Lower bound of the confidence interval.
        ci_upper: Upper bound of the confidence interval.
    """

    def __init__(self, point, ci_lower, ci_upper):
        self.point = point
        self.ci_lower = ci_lower
        self.ci_upper = ci_upper

    def _repr_html_(self):
        """Rich display hook: the summary contains no markup-sensitive characters."""
        return self.summary()

    def summary(self):
        """Return the estimate as 'point (lower–upper)' to 2 decimal places."""
        return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper)


class TTestResult:
    """
    Result of a Student's t test

    Attributes:
        statistic: The t statistic.
        dof: Degrees of freedom.
        pvalue: The p value.
        delta: Mean difference, as an Estimate.
    """

    def __init__(self, statistic, dof, pvalue, delta):
        self.statistic = statistic
        self.dof = dof
        self.pvalue = pvalue
        self.delta = delta

    def _repr_html_(self):
        """Rich display hook: two lines separated by an HTML line break."""
        return 't({:.0f}) = {:.2f}; p {}<br>δ (95% CI) = {}'.format(
            self.dof, self.statistic, fmt_p_html(self.pvalue), self.delta.summary()
        )

    def summary(self):
        """Return the plaintext equivalent of the HTML representation."""
        return 't({:.0f}) = {:.2f}; p {}\nδ (95% CI) = {}'.format(
            self.dof, self.statistic, fmt_p_text(self.pvalue), self.delta.summary()
        )


def ttest_ind(df, dep, ind, *, nan_policy='warn'):
    """Perform an independent-sample Student's t test.

    Parameters:
        df: pandas DataFrame containing the data.
        dep: Name of the column holding the dependent (numeric) variable.
        ind: Name of the column holding the grouping variable; it must take
            exactly 2 distinct values.
        nan_policy: How to handle NaN in the two columns; see check_nan.

    Returns:
        A TTestResult. The mean difference (delta) and its confidence
        interval are computed as (second group) - (first group), in the
        order the groups are yielded by DataFrame.groupby.

    Raises:
        ValueError: If ind does not take exactly 2 distinct values, or as
            raised by check_nan.
    """

    df = check_nan(df[[ind, dep]], nan_policy)

    # Split dep by the values of ind. Iterating the groupby (rather than
    # looking rows up again by index label) stays correct even when the
    # DataFrame index contains duplicate labels.
    groups = [values for _, values in df.groupby(ind)[dep]]

    # Ensure only 2 groups to compare
    if len(groups) != 2:
        raise ValueError('Got {} values for {}, expected 2'.format(len(groups), ind))

    group1, group2 = groups

    # Do t test
    # Use statsmodels rather than SciPy because this provides the mean difference automatically
    d1 = sm.stats.DescrStatsW(group1)
    d2 = sm.stats.DescrStatsW(group2)

    cm = sm.stats.CompareMeans(d2, d1)  # This order to get correct CI
    statistic, pvalue, dof = cm.ttest_ind()

    delta = d2.mean - d1.mean
    ci0, ci1 = cm.tconfint_diff()

    return TTestResult(statistic=statistic, dof=dof, pvalue=pvalue, delta=Estimate(delta, ci0, ci1))