# scipy-yli/yli/survival.py
# (listing metadata: 155 lines, 4.5 KiB, Python)

# scipy-yli: Helpful SciPy utilities and recipes
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from scipy import stats
import statsmodels.api as sm
from .config import config
from .sig_tests import ChiSquaredResult
from .utils import check_nan
def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
    """
    Generate a Kaplan–Meier plot
    
    Uses the Python *matplotlib* library.
    
    If *time* is a timedelta column, it is converted to a numeric duration and
    rescaled to the largest of years, weeks, days, hours or minutes that the
    longest time exceeds (otherwise seconds); the chosen unit is shown in the
    x-axis label.
    
    :param df: Data to generate plot for
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param ci: Whether to plot confidence intervals around the survival function
    :type ci: bool
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str
    
    :rtype: (Figure, Axes)
    """
    
    import matplotlib.pyplot as plt
    
    # Check for/clean NaNs
    columns = [time, status] if by is None else [time, status, by]
    df = check_nan(df[columns], nan_policy)
    
    time_units = None
    # dtype kind 'm' matches timedelta64 at any resolution (ns, us, ms, s)
    if df[time].dtype.kind == 'm':
        # Copy before assigning so the conversion never writes through to a frame
        # that check_nan may share with the caller — NOTE(review): depends on check_nan
        df = df.copy()
        scaled, time_units = _scale_time_units(df[time].dt.total_seconds())
        df[time] = scaled
    
    fig, ax = plt.subplots()
    
    if by is not None:
        # Group by independent variable and draw one labelled curve per group
        groups = df.groupby(by)
        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc(ax, subset[time], subset[status], ci)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc(ax, df[time], df[status], ci)
    
    if time_units:
        ax.set_xlabel('{} ({})'.format(time, time_units))
    else:
        ax.set_xlabel(time)
    ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
    ax.set_ylim(0, 1)
    
    # Only stratified curves carry labels (the CI band is labelled '_', which matplotlib
    # hides); calling legend() with nothing labelled draws an empty legend and warns
    if by is not None:
        ax.legend()
    
    return fig, ax

def _scale_time_units(seconds):
    """
    Rescale durations given in seconds to the largest sensible time unit
    
    A unit is chosen when the longest duration exceeds one whole unit; the values
    are then divided by that same unit length, so the threshold and divisor can
    never disagree.
    
    :param seconds: Durations in seconds
    :type seconds: Series
    
    :return: (rescaled durations, unit name)
    """
    
    # (unit name, length in seconds), largest unit first
    units = (
        ('years', 365.24*24*60*60),
        ('weeks', 7*24*60*60),
        ('days', 24*60*60),
        ('hours', 60*60),
        ('minutes', 60),
    )
    
    longest = seconds.max()
    for name, length in units:
        if longest > length:
            return seconds / length, name
    
    # Short durations (or an empty/all-NaN column, whose max is NaN) stay in seconds
    return seconds, 'seconds'
def plot_survfunc(ax, time, status, ci):
    """
    Draw a single Kaplan–Meier step curve on *ax*, optionally with a shaded confidence band
    
    :param ax: Axes to draw on
    :type ax: Axes
    :param time: Times to event or censoring
    :param status: Status variable (True/False or 1/0)
    :param ci: Whether to shade the pointwise normal-approximation confidence band at level ``1 - config.alpha``
    :type ci: bool
    
    :return: The line artist returned by ``ax.plot`` for the survival curve
    """
    
    survfunc = sm.SurvfuncRight(time, status)
    
    # Duplicate every point so that consecutive (x, y) pairs trace a horizontal
    # segment at the current probability followed by a vertical drop to the next
    step_x = survfunc.surv_times.repeat(2)[1:]
    
    def as_steps(values):
        # Match step_x: keep each value for two points, dropping the final duplicate
        return values.repeat(2)[:-1]
    
    line, = ax.plot(step_x, as_steps(survfunc.surv_prob))
    
    if not ci:
        return line
    
    # Half-width of the two-sided interval: z-critical value times the standard error
    half_width = -stats.norm.ppf(config.alpha/2) * survfunc.surv_prob_se
    lower = as_steps(survfunc.surv_prob - half_width)
    upper = as_steps(survfunc.surv_prob + half_width)
    
    # A label starting with '_' keeps the band out of any legend
    ax.fill_between(step_x, lower, upper, alpha=0.3, label='_')
    
    return line
def logrank(df, time, status, by, nan_policy='warn'):
    """
    Perform the log-rank test for equality of survival functions
    
    :param df: Data to perform the test on
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str
    
    :rtype: :class:`yli.sig_tests.ChiSquaredResult`
    """
    
    # TODO: Example
    
    # Check for/clean NaNs
    df = check_nan(df[[time, status, by]], nan_policy)
    
    # dtype kind 'm' matches timedelta64 at any resolution (ns, us, ms, s)
    if df[time].dtype.kind == 'm':
        # Only the ordering of times matters to the test, so plain seconds suffice;
        # copy first so the assignment never writes through to a shared frame
        df = df.copy()
        df[time] = df[time].dt.total_seconds()
    
    statistic, pvalue = sm.duration.survdiff(df[time], df[status], df[by])
    
    # Comparing k groups gives a chi-squared statistic on k - 1 degrees of freedom
    # (previously hard-coded to 1, which is only correct for two groups)
    dof = df[by].nunique() - 1
    
    return ChiSquaredResult(statistic=statistic, dof=dof, pvalue=pvalue)