From 7e8418eb364efeba0ff657fcd58a606cc7dd3bc1 Mon Sep 17 00:00:00 2001
From: RunasSudo
Date: Thu, 13 Oct 2022 13:25:24 +1100
Subject: [PATCH] Implement chi2

---
 tests/test_chi2.py | 69 ++++++++++++++++++++++++++++++++++++++++++++
 yli/__init__.py    |  2 +-
 yli/sig_tests.py   | 95 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 165 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_chi2.py

diff --git a/tests/test_chi2.py b/tests/test_chi2.py
new file mode 100644
index 0000000..c89081a
--- /dev/null
+++ b/tests/test_chi2.py
@@ -0,0 +1,69 @@
+# scipy-yli: Helpful SciPy utilities and recipes
+# Copyright © 2022 Lee Yingtong Li (RunasSudo)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from pytest import approx
+
+import numpy as np
+import pandas as pd
+
+import yli
+
+def test_chi2_ol10_15():
+    """Compare yli.chi2 for Ott & Longnecker (2016) example 10.15"""
+
+    data = [
+        (1, 'Moderate', 15),
+        (2, 'Moderate', 32),
+        (3, 'Moderate', 18),
+        (4, 'Moderate', 5),
+        (1, 'Mildly Severe', 8),
+        (2, 'Mildly Severe', 29),
+        (3, 'Mildly Severe', 23),
+        (4, 'Mildly Severe', 18),
+        (1, 'Severe', 1),
+        (2, 'Severe', 20),
+        (3, 'Severe', 25),
+        (4, 'Severe', 22)
+    ]
+
+    df = pd.DataFrame({
+        'AgeCategory': np.repeat([d[0] for d in data], [d[2] for d in data]),
+        'Severity': np.repeat([d[1] for d in data], [d[2] for d in data])
+    })
+
+    result = yli.chi2(df, 'Severity', 'AgeCategory')
+    assert result.statistic == approx(27.13, abs=0.01)
+    assert result.pvalue == approx(0.00014, abs=0.00001)
+
+def test_chi2_ol10_18():
+    """Compare yli.chi2 for Ott & Longnecker (2016) example 10.18"""
+
+    data = [
+        (False, False, 250),
+        (True, False, 750),
+        (False, True, 400),
+        (True, True, 1600)
+    ]
+
+    df = pd.DataFrame({
+        'Response': np.repeat([d[0] for d in data], [d[2] for d in data]),
+        'Stress': np.repeat([d[1] for d in data], [d[2] for d in data])
+    })
+
+    result = yli.chi2(df, 'Stress', 'Response')
+    assert result.oddsratio.point == approx(1.333, abs=0.001)
+    assert result.oddsratio.ci_lower == approx(1.113, abs=0.001)
+    assert result.oddsratio.ci_upper == approx(1.596, abs=0.001)
diff --git a/yli/__init__.py b/yli/__init__.py
index 48f931a..2a3748e 100644
--- a/yli/__init__.py
+++ b/yli/__init__.py
@@ -15,7 +15,7 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
 from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist
-from .sig_tests import mannwhitney, ttest_ind
+from .sig_tests import chi2, mannwhitney, ttest_ind
 
 def reload_me():
     import importlib
diff --git a/yli/sig_tests.py b/yli/sig_tests.py
index 2dca8bd..bca31b0 100644
--- a/yli/sig_tests.py
+++ b/yli/sig_tests.py
@@ -14,6 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
+import numpy as np
 import pandas as pd
 from scipy import stats
 import statsmodels.api as sm
@@ -155,3 +156,97 @@ def mannwhitney(df, dep, ind, *, nan_policy='warn', brunnermunzel=True, use_cont
         statistic=min(u1, u2), pvalue=result.pvalue,
         #med1=data1.median(), med2=data2.median(),
         rank_biserial=r, direction=('{1} > {0}' if u1 < u2 else '{0} > {1}').format(group1, group2))
+
+# ------------------------
+# Pearson chi-squared test
+
+class PearsonChiSquaredResult:
+    """
+    Result of a Pearson chi-squared test
+
+    Stores the contingency table (ct), the test statistic, its degrees of freedom (dof)
+    and the p value. oddsratio and riskratio are None unless supplied; chi2 supplies
+    them for 2x2 tables only.
+    """
+
+    def __init__(self, ct, statistic, dof, pvalue, oddsratio=None, riskratio=None):
+        self.ct = ct
+        self.statistic = statistic
+        self.dof = dof
+        self.pvalue = pvalue
+        self.oddsratio = oddsratio
+        self.riskratio = riskratio
+
+    def _repr_html_(self):
+        """Return the contingency table and test result as HTML"""
+
+        if self.oddsratio is not None:
+            return '{}<br>χ<sup>2</sup>({}) = {:.2f}; p {}<br>OR (95% CI) = {}<br>RR (95% CI) = {}'.format(
+                self.ct._repr_html_(), self.dof, self.statistic, fmt_p_html(self.pvalue), self.oddsratio.summary(), self.riskratio.summary())
+        else:
+            return '{}<br>χ<sup>2</sup>({}) = {:.2f}; p {}'.format(
+                self.ct._repr_html_(), self.dof, self.statistic, fmt_p_html(self.pvalue))
+
+    def summary(self):
+        """Return the contingency table and test result as plain text"""
+
+        if self.oddsratio is not None:
+            return '{}\nχ²({}) = {:.2f}; p {}\nOR (95% CI) = {}\nRR (95% CI) = {}'.format(
+                self.ct, self.dof, self.statistic, fmt_p_text(self.pvalue), self.oddsratio.summary(), self.riskratio.summary())
+        else:
+            return '{}\nχ²({}) = {:.2f}; p {}'.format(
+                self.ct, self.dof, self.statistic, fmt_p_text(self.pvalue))
+
+def chi2(df, dep, ind, *, nan_policy='warn'):
+    """
+    Perform a Pearson chi-squared test
+
+    df: DataFrame containing the data
+    dep: Name of the column forming the columns of the contingency table
+    ind: Name of the column forming the rows of the contingency table
+    nan_policy: Passed to check_nan to control handling of missing values
+
+    Returns a PearsonChiSquaredResult. For a 2x2 table, this includes the odds ratio
+    and risk ratio with confidence limits from statsmodels.
+
+    Warns if more than 20% of cells have expected count < 5, or if any cell has
+    expected count < 1.
+    """
+
+    # Check for/clean NaNs
+    df = check_nan(df[[ind, dep]], nan_policy)
+
+    # Compute contingency table
+    ct = pd.crosstab(df[ind], df[dep])
+
+    # Get expected counts
+    expected = stats.contingency.expected_freq(ct)
+
+    # Warn on low expected counts
+    n_low = (expected < 5).sum()
+    if n_low / expected.size > 0.2:
+        warnings.warn('{} of {} cells ({:.0f}%) have expected count < 5'.format(n_low, expected.size, n_low / expected.size * 100))
+    if (expected < 1).any():
+        warnings.warn('{} cells have expected count < 1'.format((expected < 1).sum()))
+
+    if ct.shape == (2,2):
+        # 2x2 table
+        # Use statsmodels to get OR and RR
+
+        # np.flip reverses both axes, so the last level of each variable comes first
+        # NOTE(review): presumably so OR/RR are expressed for the last levels (e.g. True) - confirm
+        smct = sm.stats.Table2x2(np.flip(ct.to_numpy()), shift_zeros=False)
+        result = smct.test_nominal_association()
+        ORci = smct.oddsratio_confint()
+        RRci = smct.riskratio_confint()
+
+        return PearsonChiSquaredResult(
+            ct=ct, statistic=result.statistic, dof=result.df, pvalue=result.pvalue,
+            oddsratio=Estimate(smct.oddsratio, ORci[0], ORci[1]), riskratio=Estimate(smct.riskratio, RRci[0], RRci[1]))
+    else:
+        # rxc table
+        # Just use SciPy
+
+        statistic, pvalue, dof, _ = stats.chi2_contingency(ct, correction=False)
+
+        return PearsonChiSquaredResult(ct=ct, statistic=statistic, dof=dof, pvalue=pvalue)