scipy-yli/yli/survival.py
RunasSudo 8238383edb
Fix up yli.turnbull
Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function)
Draw survival curve from time 0, survival 100%
2023-04-22 01:18:38 +10:00

287 lines
9.4 KiB
Python

# scipy-yli: Helpful SciPy utilities and recipes
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
from scipy import stats
import statsmodels.api as sm
from .config import config
from .sig_tests import ChiSquaredResult
from .utils import Estimate, check_nan
def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn'):
    """
    Generate a Kaplan–Meier plot

    Uses the Python *matplotlib* library.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param ci: Whether to plot confidence intervals around the survival function
    :type ci: bool
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    # Check for/clean NaNs
    # Use "is None" rather than truthiness so that the same test decides both which columns are
    # kept here and whether we group below (a falsy column name such as 0 must still be kept)
    columns = [time, status] if by is None else [time, status, by]
    df = check_nan(df[columns], nan_policy)

    # Convert timedelta to numeric
    df, time_units = survtime_to_numeric(df, time)

    fig, ax = plt.subplots()

    if by is not None:
        # Group by independent variable
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_kaplanmeier(ax, subset[time], subset[status], ci, transform_x, transform_y)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_kaplanmeier(ax, df[time], df[status], ci, transform_x, transform_y)

    if time_units:
        ax.set_xlabel('{} ({})'.format(time, time_units))
    else:
        ax.set_xlabel(time)
    ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    if by is not None:
        # Only the per-group curves are labelled; calling legend() with no labelled
        # artists (ungrouped case) just produces a matplotlib warning
        ax.legend()

    return fig, ax
def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None):
    """
    Plot one Kaplan–Meier survival curve on *ax*, optionally with a confidence band

    The confidence band is the pointwise normal approximation, survival ± z* × SE, at level
    1 − *config.alpha*.

    :param ax: Axes to draw on
    :param time: Times to event (numeric)
    :param status: Event indicator (True/False or 1/0)
    :param ci: Whether to shade the confidence band
    :param transform_x: Optional function applied to the x coordinates before drawing
    :param transform_y: Optional function applied to the y coordinates (curve and band) before drawing

    :return: The line artist for the survival curve
    """

    # Estimate the survival function
    sf = sm.SurvfuncRight(time, status)

    # Draw straight lines
    # np.concatenate(...) to force starting drawing from time 0, survival 100%
    xpoints = np.concatenate([[0], sf.surv_times]).repeat(2)[1:]
    ypoints = np.concatenate([[1], sf.surv_prob]).repeat(2)[:-1]

    # Transform BEFORE plotting so the curve is drawn in the same coordinates as its CI band
    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    handle = ax.plot(xpoints, ypoints)[0]

    if ci:
        zstar = -stats.norm.ppf(config.alpha/2)

        # Get confidence intervals
        ci0 = sf.surv_prob - zstar * sf.surv_prob_se
        ci1 = sf.surv_prob + zstar * sf.surv_prob_se

        # Plot confidence intervals (same step shape as the curve, starting at survival 100%)
        ypoints0 = np.concatenate([[1], ci0]).repeat(2)[:-1]
        ypoints1 = np.concatenate([[1], ci1]).repeat(2)[:-1]

        if transform_y:
            ypoints0 = transform_y(ypoints0)
            ypoints1 = transform_y(ypoints1)

        # label='_' keeps the band out of the legend
        ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')

    return handle
def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'):
    """
    Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations

    The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.

    By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.

    Uses the Python *lifelines* and *matplotlib* libraries.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time_left: Column in *df* for the time to event, left interval endpoint (numeric)
    :type time_left: str
    :param time_right: Column in *df* for the time to event, right interval endpoint (numeric)
    :type time_right: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
    :type step_loc: float
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    # Check for/clean NaNs
    # Use "is None" rather than truthiness so that the same test decides both which columns are
    # kept here and whether we group below (a falsy column name such as 0 must still be kept)
    columns = [time_left, time_right] if by is None else [time_left, time_right, by]
    df = check_nan(df[columns], nan_policy)

    fig, ax = plt.subplots()

    if by is not None:
        # Group by independent variable
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y)

    ax.set_xlabel('Analysis time')
    ax.set_ylabel('Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    if by is not None:
        # Only the per-group curves are labelled; calling legend() with no labelled
        # artists (ungrouped case) just produces a matplotlib warning
        ax.legend()

    return fig, ax
def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
    """
    Draw a single Turnbull survival step function on *ax*

    The vertices come from :func:`calc_survfunc_turnbull`, which receives every argument except *ax*.

    :return: The line artist created by ``ax.plot``
    """

    step_x, step_y = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y)
    lines = ax.plot(step_x, step_y)
    return lines[0]
def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
    """
    Compute the vertices of the Turnbull (NPMLE) survival step function

    :param time_left: Left endpoints of the observation intervals (numeric)
    :param time_right: Right endpoints of the observation intervals (numeric; *np.inf* for right-censored)
    :param step_loc: Proportion along each Turnbull interval at which the curve steps down
        (0 = left bound, 1 = right bound, 0.5 = midpoint)
    :param transform_x: Optional function applied to the x coordinates
    :param transform_y: Optional function applied to the y coordinates

    :return: (*xpoints*, *ypoints*) describing a step function starting at time 0, survival 100%
    """

    from lifelines.fitters.npmle import npmle

    EPSILON = 1e-10

    # TODO: Support left == right => failure was exactly observed
    followup_left = time_left + EPSILON  # Add epsilon to make interval half-open
    followup_right = time_right

    # Estimate the survival function
    # Call lifelines.fitters.npmle.npmle directly (rather than KaplanMeierFitter.fit_interval_censoring)
    # so we have access to the Turnbull intervals themselves
    sf_probs, turnbull_intervals = npmle(followup_left, followup_right)

    def step_point(interval):
        # Return the endpoint itself at step_loc 0 or 1. Otherwise an unbounded interval
        # (right == inf) at step_loc 0 would give left + inf*0 = nan instead of its left bound
        if step_loc == 0:
            return interval.left
        if step_loc == 1:
            return interval.right
        return interval.left*(1-step_loc) + interval.right*step_loc

    xpoints = [step_point(i) for i in turnbull_intervals]
    # Clip away floating-point noise in the cumulative sum (e.g. a final value of -2e-16)
    ypoints = np.clip(1 - np.cumsum(sf_probs), 0, 1)

    # Draw straight lines
    # np.concatenate(...) to force starting drawing from time 0, survival 100%
    xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
    ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]

    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    return xpoints, ypoints
def survtime_to_numeric(df, time):
    """
    Convert pandas timedelta dtype to float64, auto-detecting the best time unit to display

    The largest unit (years, weeks, days, hours, minutes) that is exceeded by the maximum duration
    is chosen; otherwise durations are given in seconds.

    :param df: Data to check for pandas timedelta dtype
    :type df: DataFrame
    :param time: Column to check for pandas timedelta dtype
    :type time: str

    :return: (*df*, *time_units*)

        * **df** (*DataFrame*) – Data with pandas timedelta dtypes converted, which is *not* copied
        * **time_units** (*str*) – Human-readable description of the time unit, or *None* if not converted
    """

    # dtype.kind == 'm' matches timedelta64 at any resolution (ns, us, ms, s);
    # pandas extension dtypes also expose .kind, so this check never raises
    if df[time].dtype.kind != 'm':
        return df, None

    df[time] = df[time].dt.total_seconds()

    # Auto-detect best time units, largest first. Each threshold is the length of its own unit,
    # so every branch is reachable (durations over one day but under one week are reported in days)
    units = [
        (365.24*24*60*60, 'years'),
        (7*24*60*60, 'weeks'),
        (24*60*60, 'days'),
        (60*60, 'hours'),
        (60, 'minutes'),
    ]
    max_seconds = df[time].max()
    for unit_seconds, time_units in units:
        if max_seconds > unit_seconds:
            df[time] = df[time] / unit_seconds
            return df, time_units

    return df, 'seconds'
def logrank(df, time, status, by, nan_policy='warn'):
    """
    Perform the log-rank test for equality of survival functions

    :param df: Data to perform the test on
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: :class:`yli.sig_tests.ChiSquaredResult`
    """

    # TODO: Example

    # Check for/clean NaNs
    df = check_nan(df[[time, status, by]], nan_policy)

    # Convert timedelta (any resolution, dtype kind 'm') to seconds
    if df[time].dtype.kind == 'm':
        df[time] = df[time].dt.total_seconds()

    statistic, pvalue = sm.duration.survdiff(df[time], df[status], df[by])

    # Comparing k groups gives a chi-squared statistic on k - 1 degrees of freedom
    # (hard-coding 1 was only correct for exactly two groups)
    dof = df[by].nunique() - 1

    return ChiSquaredResult(statistic=statistic, dof=dof, pvalue=pvalue)