Implement yli.turnbull
This commit is contained in:
parent
d359820f42
commit
e12dd65fdc
@ -71,6 +71,7 @@ Relevant statistical functions are all directly available from the top-level *yl
|
||||
* Survival analysis:
|
||||
* *kaplanmeier*: Kaplan–Meier plot
|
||||
* *logrank*: Log-rank test
|
||||
* *turnbull*: Turnbull estimator plot for interval-censored data
|
||||
* Input/output:
|
||||
* *pickle_write_compressed*, *pickle_read_compressed*: Pickle a pandas DataFrame and compress using LZMA
|
||||
* *pickle_write_encrypted*, *pickle_read_encrypted*: Pickle a pandas DataFrame, compress using LZMA, and encrypt
|
||||
|
@ -7,3 +7,5 @@ Functions
|
||||
.. autofunction:: yli.kaplanmeier
|
||||
|
||||
.. autofunction:: yli.logrank
|
||||
|
||||
.. autofunction:: yli.turnbull
|
||||
|
@ -22,7 +22,7 @@ from .graphs import init_fonts
|
||||
from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
|
||||
from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif
|
||||
from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind
|
||||
from .survival import kaplanmeier, logrank
|
||||
from .survival import kaplanmeier, logrank, turnbull
|
||||
from .utils import as_ordinal
|
||||
|
||||
def reload_me():
|
||||
|
219
yli/survival.py
219
yli/survival.py
@ -21,7 +21,7 @@ from .config import config
|
||||
from .sig_tests import ChiSquaredResult
|
||||
from .utils import check_nan
|
||||
|
||||
def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
|
||||
def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn'):
|
||||
"""
|
||||
Generate a Kaplan–Meier plot
|
||||
|
||||
@ -37,6 +37,10 @@ def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
|
||||
:type by: str
|
||||
:param ci: Whether to plot confidence intervals around the survival function
|
||||
:type ci: bool
|
||||
:param transform_x: Function to transform x axis by
|
||||
:type transform_x: callable
|
||||
:param transform_y: Function to transform y axis by
|
||||
:type transform_y: callable
|
||||
:param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
|
||||
:type nan_policy: str
|
||||
|
||||
@ -51,6 +55,167 @@ def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
|
||||
else:
|
||||
df = check_nan(df[[time, status]], nan_policy)
|
||||
|
||||
# Covert timedelta to numeric
|
||||
df, time_units = survtime_to_numeric(df, time)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
if by is not None:
|
||||
# Group by independent variable
|
||||
groups = df.groupby(by)
|
||||
|
||||
for group in groups.groups:
|
||||
subset = groups.get_group(group)
|
||||
handle = plot_survfunc_kaplanmeier(ax, subset[time], subset[status], ci, transform_x, transform_y)
|
||||
handle.set_label('{} = {}'.format(by, group))
|
||||
else:
|
||||
# No grouping
|
||||
plot_survfunc_kaplanmeier(ax, df[time], df[status], ci, transform_x, transform_y)
|
||||
|
||||
if time_units:
|
||||
ax.set_xlabel('{} ({})'.format(time, time_units))
|
||||
else:
|
||||
ax.set_xlabel(time)
|
||||
ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
|
||||
ax.set_xlim(left=0)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend()
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None):
|
||||
# Estimate the survival function
|
||||
sf = sm.SurvfuncRight(time, status)
|
||||
|
||||
# Draw straight lines
|
||||
xpoints = sf.surv_times.repeat(2)[1:]
|
||||
ypoints = sf.surv_prob.repeat(2)[:-1]
|
||||
handle = ax.plot(xpoints, ypoints)[0]
|
||||
|
||||
if transform_x:
|
||||
xpoints = transform_x(xpoints)
|
||||
if transform_y:
|
||||
ypoints = transform_y(ypoints)
|
||||
|
||||
if ci:
|
||||
zstar = -stats.norm.ppf(config.alpha/2)
|
||||
|
||||
# Get confidence intervals
|
||||
ci0 = sf.surv_prob - zstar * sf.surv_prob_se
|
||||
ci1 = sf.surv_prob + zstar * sf.surv_prob_se
|
||||
|
||||
# Plot confidence intervals
|
||||
ypoints0 = ci0.repeat(2)[:-1]
|
||||
ypoints1 = ci1.repeat(2)[:-1]
|
||||
|
||||
if transform_y:
|
||||
ypoints0 = transform_y(ypoints0)
|
||||
ypoints1 = transform_y(ypoints1)
|
||||
|
||||
ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')
|
||||
|
||||
return handle
|
||||
|
||||
def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'):
|
||||
"""
|
||||
Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations
|
||||
|
||||
The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.
|
||||
|
||||
For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval.
|
||||
|
||||
Uses the Python *lifelines* and *matplotlib* libraries.
|
||||
|
||||
:param df: Data to generate plot for
|
||||
:type df: DataFrame
|
||||
:param time_left: Column in *df* for the time to event, left interval endpoint (numeric)
|
||||
:type time_left: str
|
||||
:param time_right: Column in *df* for the time to event, right interval endpoint (numeric)
|
||||
:type time_right: str
|
||||
:param by: Column in *df* to stratify by (categorical)
|
||||
:type by: str
|
||||
:param transform_x: Function to transform x axis by
|
||||
:type transform_x: callable
|
||||
:param transform_y: Function to transform y axis by
|
||||
:type transform_y: callable
|
||||
:param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
|
||||
:type nan_policy: str
|
||||
|
||||
:rtype: (Figure, Axes)
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Check for/clean NaNs
|
||||
if by:
|
||||
df = check_nan(df[[time_left, time_right, by]], nan_policy)
|
||||
else:
|
||||
df = check_nan(df[[time_left, time_right]], nan_policy)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
if by is not None:
|
||||
# Group by independent variable
|
||||
groups = df.groupby(by)
|
||||
|
||||
for group in groups.groups:
|
||||
subset = groups.get_group(group)
|
||||
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y)
|
||||
handle.set_label('{} = {}'.format(by, group))
|
||||
else:
|
||||
# No grouping
|
||||
plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y)
|
||||
|
||||
ax.set_xlabel('Analysis time')
|
||||
ax.set_ylabel('Survival probability')
|
||||
ax.set_xlim(left=0)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend()
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None):
|
||||
import lifelines
|
||||
|
||||
EPSILON = 1e-10
|
||||
|
||||
# TODO: Support left == right => failure was exactly observed
|
||||
|
||||
followup_left = time_left + EPSILON # Add epsilon to make interval half-open
|
||||
followup_right = time_right
|
||||
|
||||
# Estimate the survival function
|
||||
sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
|
||||
|
||||
# Draw straight lines
|
||||
xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1]
|
||||
med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2
|
||||
ypoints = med.to_numpy().repeat(2)[1:]
|
||||
|
||||
if transform_x:
|
||||
xpoints = transform_x(xpoints)
|
||||
if transform_y:
|
||||
ypoints = transform_y(ypoints)
|
||||
|
||||
handle = ax.plot(xpoints, ypoints)[0]
|
||||
|
||||
return handle
|
||||
|
||||
def survtime_to_numeric(df, time):
|
||||
"""
|
||||
Convert pandas timedelta dtype to float64, auto-detecting the best time unit to display
|
||||
|
||||
:param df: Data to check for pandas timedelta dtype
|
||||
:type df: DataFrame
|
||||
:param time: Column to check for pandas timedelta dtype
|
||||
:type df: DataFrame
|
||||
|
||||
:return: (*df*, *time_units*)
|
||||
|
||||
* **df** (*DataFrame*) – Data with pandas timedelta dtypes converted, which is *not* copied
|
||||
* **time_units** (*str*) – Human-readable description of the time unit, or *None* if not converted
|
||||
"""
|
||||
|
||||
if df[time].dtype == '<m8[ns]':
|
||||
df[time] = df[time].dt.total_seconds()
|
||||
|
||||
@ -72,56 +237,10 @@ def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
|
||||
time_units = 'minutes'
|
||||
else:
|
||||
time_units = 'seconds'
|
||||
|
||||
return df, time_units
|
||||
else:
|
||||
time_units = None
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
if by is not None:
|
||||
# Group by independent variable
|
||||
groups = df.groupby(by)
|
||||
|
||||
for group in groups.groups:
|
||||
subset = groups.get_group(group)
|
||||
handle = plot_survfunc(ax, subset[time], subset[status], ci)
|
||||
handle.set_label('{} = {}'.format(by, group))
|
||||
else:
|
||||
# No grouping
|
||||
plot_survfunc(ax, df[time], df[status], ci)
|
||||
|
||||
if time_units:
|
||||
ax.set_xlabel('{} ({})'.format(time, time_units))
|
||||
else:
|
||||
ax.set_xlabel(time)
|
||||
ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend()
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_survfunc(ax, time, status, ci):
|
||||
# Estimate the survival function
|
||||
sf = sm.SurvfuncRight(time, status)
|
||||
|
||||
# Draw straight lines
|
||||
xpoints = sf.surv_times.repeat(2)[1:]
|
||||
ypoints = sf.surv_prob.repeat(2)[:-1]
|
||||
handle = ax.plot(xpoints, ypoints)[0]
|
||||
|
||||
if ci:
|
||||
zstar = -stats.norm.ppf(config.alpha/2)
|
||||
|
||||
# Get confidence intervals
|
||||
ci0 = sf.surv_prob - zstar * sf.surv_prob_se
|
||||
ci1 = sf.surv_prob + zstar * sf.surv_prob_se
|
||||
|
||||
# Plot confidence intervals
|
||||
ypoints0 = ci0.repeat(2)[:-1]
|
||||
ypoints1 = ci1.repeat(2)[:-1]
|
||||
|
||||
ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')
|
||||
|
||||
return handle
|
||||
return df, None
|
||||
|
||||
def logrank(df, time, status, by, nan_policy='warn'):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user