From 3d30045832e0536b56795334decb81aee744b0af Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 22 Apr 2023 01:18:02 +1000 Subject: [PATCH] Add unit test for yli.kaplanmeier --- tests/test_kaplanmeier.py | 58 +++++++++++++++++++++++++++++++++++++++ yli/survival.py | 15 ++++++++-- 2 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 tests/test_kaplanmeier.py diff --git a/tests/test_kaplanmeier.py b/tests/test_kaplanmeier.py new file mode 100644 index 0000000..f3ab035 --- /dev/null +++ b/tests/test_kaplanmeier.py @@ -0,0 +1,58 @@ +# scipy-yli: Helpful SciPy utilities and recipes +# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from pytest import approx + +import pandas as pd + +import yli + +def test_kaplanmeier_simple(): + """Compare yli.kaplanmeier for simple example""" + + df = pd.DataFrame({ + 'SurvTime': [2, 4, 6, 8], + 'Status': [True, True, True, True] + }) + + xpoints, ypoints, _, _ = yli.survival.calc_survfunc_kaplanmeier(df['SurvTime'], df['Status'], False) + + assert xpoints[0] == 0 + assert ypoints[0] == 1 + + assert xpoints[1] == 2 + assert ypoints[1] == 1 + + assert xpoints[2] == 2 + assert ypoints[2] == 0.75 + + assert xpoints[3] == 4 + assert ypoints[3] == 0.75 + + assert xpoints[4] == 4 + assert ypoints[4] == approx(0.5) + + assert xpoints[5] == 6 + assert ypoints[5] == approx(0.5) + + assert xpoints[6] == 6 + assert ypoints[6] == approx(0.25) + + assert xpoints[7] == 8 + assert ypoints[7] == approx(0.25) + + assert xpoints[8] == 8 + assert ypoints[8] == 0 diff --git a/yli/survival.py b/yli/survival.py index 9f68e54..aa36fb3 100644 --- a/yli/survival.py +++ b/yli/survival.py @@ -87,6 +87,16 @@ def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transfo return fig, ax def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None): + xpoints, ypoints, ypoints0, ypoints1 = calc_survfunc_kaplanmeier(time, status, ci, transform_x, transform_y) + + handle = ax.plot(xpoints, ypoints)[0] + + if ci: + ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_') + + return handle + +def calc_survfunc_kaplanmeier(time, status, ci, transform_x=None, transform_y=None): # Estimate the survival function sf = sm.SurvfuncRight(time, status) @@ -94,7 +104,6 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_ # np.concatenate(...) to force starting drawing from time 0, survival 100% xpoints = np.concatenate([[0], sf.surv_times]).repeat(2)[1:] ypoints = np.concatenate([[1], sf.surv_prob]).repeat(2)[:-1] - handle = ax.plot(xpoints, ypoints)[0] if transform_x: xpoints = transform_x(xpoints) @@ -116,9 +125,9 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_ ypoints0 = transform_y(ypoints0) ypoints1 = transform_y(ypoints1) - ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_') + return xpoints, ypoints, ypoints0, ypoints1 - return handle + return xpoints, ypoints, None, None def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'): """