RunasSudo
8238383edb
Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function) Draw survival curve from time 0, survival 100%
287 lines
9.4 KiB
Python
287 lines
9.4 KiB
Python
# scipy-yli: Helpful SciPy utilities and recipes
|
|
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import numpy as np
|
|
from scipy import stats
|
|
import statsmodels.api as sm
|
|
|
|
from .config import config
|
|
from .sig_tests import ChiSquaredResult
|
|
from .utils import Estimate, check_nan
|
|
|
|
def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn'):
    """
    Generate a Kaplan–Meier plot

    Uses the Python *matplotlib* library.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical), or *None* for a single unstratified curve
    :type by: str
    :param ci: Whether to plot confidence intervals around the survival function
    :type ci: bool
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    stratified = by is not None

    # Check for/clean NaNs, considering only the columns that are actually used
    # Test "is not None" (not truthiness) so this agrees with the grouping logic below,
    # even for falsy column labels such as 0
    columns = [time, status, by] if stratified else [time, status]
    df = check_nan(df[columns], nan_policy)

    # Convert timedelta to numeric
    df, time_units = survtime_to_numeric(df, time)

    fig, ax = plt.subplots()

    if stratified:
        # Group by independent variable, drawing one labelled curve per group
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_kaplanmeier(ax, subset[time], subset[status], ci, transform_x, transform_y)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_kaplanmeier(ax, df[time], df[status], ci, transform_x, transform_y)

    if time_units:
        ax.set_xlabel('{} ({})'.format(time, time_units))
    else:
        ax.set_xlabel(time)
    ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    # Only stratified plots have labelled curves; calling legend() with no labelled
    # artists would only emit a matplotlib warning
    if stratified:
        ax.legend()

    return fig, ax
|
|
|
|
def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None):
    """
    Draw a Kaplan–Meier survival curve, and optionally its confidence band, onto existing axes

    :param ax: Axes to draw onto
    :type ax: Axes
    :param time: Times to event (numeric)
    :param status: Status indicators, passed with *time* to *statsmodels* ``SurvfuncRight``
    :param ci: Whether to shade a (1 − *config.alpha*) confidence band around the curve
    :type ci: bool
    :param transform_x: Function to transform the x coordinates by, applied to both the curve and the band
    :type transform_x: callable
    :param transform_y: Function to transform the y coordinates by, applied to both the curve and the band
    :type transform_y: callable

    :return: Handle of the survival curve line (first element returned by ``ax.plot``)
    """

    # Estimate the survival function
    sf = sm.SurvfuncRight(time, status)

    # Build step-function coordinates by duplicating each point
    # np.concatenate(...) to force starting drawing from time 0, survival 100%
    xpoints = np.concatenate([[0], sf.surv_times]).repeat(2)[1:]
    ypoints = np.concatenate([[1], sf.surv_prob]).repeat(2)[:-1]

    # Apply the transforms *before* drawing, so the curve and the confidence band
    # are plotted in the same coordinates
    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    handle = ax.plot(xpoints, ypoints)[0]

    if ci:
        # Two-sided normal critical value
        zstar = -stats.norm.ppf(config.alpha/2)

        # Get confidence intervals
        ci0 = sf.surv_prob - zstar * sf.surv_prob_se
        ci1 = sf.surv_prob + zstar * sf.surv_prob_se

        # Plot confidence intervals, using the same step layout as the curve
        ypoints0 = np.concatenate([[1], ci0]).repeat(2)[:-1]
        ypoints1 = np.concatenate([[1], ci1]).repeat(2)[:-1]

        if transform_y:
            ypoints0 = transform_y(ypoints0)
            ypoints1 = transform_y(ypoints1)

        # Label starting with an underscore keeps the band out of the legend
        ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')

    return handle
|
|
|
|
def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'):
    """
    Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations

    The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.

    By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.

    Uses the Python *lifelines* and *matplotlib* libraries.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time_left: Column in *df* for the time to event, left interval endpoint (numeric)
    :type time_left: str
    :param time_right: Column in *df* for the time to event, right interval endpoint (numeric)
    :type time_right: str
    :param by: Column in *df* to stratify by (categorical), or *None* for a single unstratified curve
    :type by: str
    :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
    :type step_loc: float
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    stratified = by is not None

    # Check for/clean NaNs, considering only the columns that are actually used
    # Test "is not None" (not truthiness) so this agrees with the grouping logic below,
    # even for falsy column labels such as 0
    columns = [time_left, time_right, by] if stratified else [time_left, time_right]
    df = check_nan(df[columns], nan_policy)

    fig, ax = plt.subplots()

    if stratified:
        # Group by independent variable, drawing one labelled curve per group
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y)

    ax.set_xlabel('Analysis time')
    ax.set_ylabel('Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    # Only stratified plots have labelled curves; calling legend() with no labelled
    # artists would only emit a matplotlib warning
    if stratified:
        ax.legend()

    return fig, ax
|
|
|
|
def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
    """
    Draw the Turnbull survival function estimate onto existing axes

    The step-function coordinates are computed by :func:`calc_survfunc_turnbull`, to which all arguments other than *ax* are forwarded unchanged.

    :return: Handle of the plotted line (first element returned by ``ax.plot``)
    """

    curve_x, curve_y = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y)
    lines = ax.plot(curve_x, curve_y)
    return lines[0]
|
|
|
|
def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
    """
    Compute step-function coordinates for the Turnbull estimate of the survival function

    Uses the Python *lifelines* library (``lifelines.fitters.npmle.npmle``).

    :param time_left: Left interval endpoints (numeric)
    :param time_right: Right interval endpoints (numeric)
    :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
    :type step_loc: float
    :param transform_x: Function to transform the x coordinates by
    :type transform_x: callable
    :param transform_y: Function to transform the y coordinates by
    :type transform_y: callable

    :return: (*xpoints*, *ypoints*), coordinates of the step function, starting from time 0, survival 100%
    """

    from lifelines.fitters.npmle import npmle

    EPSILON = 1e-10

    # TODO: Support left == right => failure was exactly observed

    followup_left = time_left + EPSILON  # Add epsilon to make interval half-open
    followup_right = time_right

    # Estimate the survival function
    # Call lifelines.fitters.npmle.npmle directly (rather than KaplanMeierFitter.fit_interval_censoring)
    # so we can access the Turnbull intervals and choose where to place each step
    sf_probs, turnbull_intervals = npmle(followup_left, followup_right)

    def step_position(interval):
        # Handle the bounds explicitly: for an interval with right == inf,
        # inf * 0 is nan, which would lose the step at the left bound when step_loc == 0
        if step_loc == 0:
            return interval.left
        if step_loc == 1:
            return interval.right
        return interval.left*(1-step_loc) + interval.right*step_loc

    xpoints = [step_position(i) for i in turnbull_intervals]
    ypoints = 1 - np.cumsum(sf_probs)

    # Build step-function coordinates by duplicating each point
    # np.concatenate(...) to force starting drawing from time 0, survival 100%
    xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
    ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]

    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    return xpoints, ypoints
|
|
|
|
def survtime_to_numeric(df, time):
    """
    Convert pandas timedelta dtype to float64, auto-detecting the best time unit to display

    The unit chosen is the largest of years, weeks, days, hours and minutes which the maximum time exceeds, falling back to seconds.

    :param df: Data to check for pandas timedelta dtype
    :type df: DataFrame
    :param time: Column to check for pandas timedelta dtype
    :type time: str

    :return: (*df*, *time_units*)

    * **df** (*DataFrame*) – Data with pandas timedelta dtypes converted, which is *not* copied
    * **time_units** (*str*) – Human-readable description of the time unit, or *None* if not converted
    """

    # dtype kind 'm' matches timedelta64 of any resolution, not only nanoseconds
    if getattr(df[time].dtype, 'kind', None) != 'm':
        return df, None

    df[time] = df[time].dt.total_seconds()

    # Candidate units from largest to smallest, as (seconds per unit, description)
    # Each unit is selected once the longest time exceeds one whole unit
    # (the weeks threshold was previously 14 hours, which made 'days' unreachable)
    candidate_units = [
        (365.24*24*60*60, 'years'),
        (7*24*60*60, 'weeks'),
        (24*60*60, 'days'),
        (60*60, 'hours'),
        (60, 'minutes'),
    ]

    # Auto-detect best time units
    longest = df[time].max()
    for unit_seconds, unit_name in candidate_units:
        if longest > unit_seconds:
            df[time] = df[time] / unit_seconds
            return df, unit_name

    # Times are already in seconds
    return df, 'seconds'
|
|
|
|
def logrank(df, time, status, by, nan_policy='warn'):
    """
    Perform the log-rank test for equality of survival functions

    Uses the Python *statsmodels* library (``statsmodels.duration.survdiff``).

    :param df: Data to perform the test on
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: :class:`yli.sig_tests.ChiSquaredResult`
    """

    # TODO: Example

    # Check for/clean NaNs
    df = check_nan(df[[time, status, by]], nan_policy)

    # Convert timedelta to numeric (seconds); the test depends only on the ordering of times
    # dtype kind 'm' matches timedelta64 of any resolution, not only nanoseconds
    if getattr(df[time].dtype, 'kind', None) == 'm':
        df[time] = df[time].dt.total_seconds()

    statistic, pvalue = sm.duration.survdiff(df[time], df[status], df[by])

    # The chi-squared statistic comparing k groups has k - 1 degrees of freedom
    # (1 in the usual two-group case)
    dof = df[by].nunique() - 1

    return ChiSquaredResult(statistic=statistic, dof=dof, pvalue=pvalue)
|