diff --git a/README.md b/README.md index 23162e2..c4036f3 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ Relevant statistical functions are all directly available from the top-level *yl * Survival analysis: * *kaplanmeier*: Kaplan–Meier plot * *logrank*: Log-rank test + * *turnbull*: Turnbull estimator plot for interval-censored data * Input/output: * *pickle_write_compressed*, *pickle_read_compressed*: Pickle a pandas DataFrame and compress using LZMA * *pickle_write_encrypted*, *pickle_read_encrypted*: Pickle a pandas DataFrame, compress using LZMA, and encrypt diff --git a/docs/survival.rst b/docs/survival.rst index 219fe6b..60b631d 100644 --- a/docs/survival.rst +++ b/docs/survival.rst @@ -7,3 +7,5 @@ Functions .. autofunction:: yli.kaplanmeier .. autofunction:: yli.logrank + +.. autofunction:: yli.turnbull diff --git a/yli/__init__.py b/yli/__init__.py index e9e4c66..f738b32 100644 --- a/yli/__init__.py +++ b/yli/__init__.py @@ -22,7 +22,7 @@ from .graphs import init_fonts from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind -from .survival import kaplanmeier, logrank +from .survival import kaplanmeier, logrank, turnbull from .utils import as_ordinal def reload_me(): diff --git a/yli/survival.py b/yli/survival.py index 0b70e37..1d48acb 100644 --- a/yli/survival.py +++ b/yli/survival.py @@ -21,7 +21,7 @@ from .config import config from .sig_tests import ChiSquaredResult from .utils import check_nan -def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'): +def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn'): """ Generate a Kaplan–Meier plot @@ -37,6 +37,10 @@ def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'): :type 
by: str :param ci: Whether to plot confidence intervals around the survival function :type ci: bool + :param transform_x: Function to transform x axis by + :type transform_x: callable + :param transform_y: Function to transform y axis by + :type transform_y: callable :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`) :type nan_policy: str @@ -51,6 +55,167 @@ def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'): else: df = check_nan(df[[time, status]], nan_policy) + # Convert timedelta to numeric + df, time_units = survtime_to_numeric(df, time) + + fig, ax = plt.subplots() + + if by is not None: + # Group by independent variable + groups = df.groupby(by) + + for group in groups.groups: + subset = groups.get_group(group) + handle = plot_survfunc_kaplanmeier(ax, subset[time], subset[status], ci, transform_x, transform_y) + handle.set_label('{} = {}'.format(by, group)) + else: + # No grouping + plot_survfunc_kaplanmeier(ax, df[time], df[status], ci, transform_x, transform_y) + + if time_units: + ax.set_xlabel('{} ({})'.format(time, time_units)) + else: + ax.set_xlabel(time) + ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability') + ax.set_xlim(left=0) + ax.set_ylim(0, 1) + ax.legend() + + return fig, ax + +def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None): + # Estimate the survival function + sf = sm.SurvfuncRight(time, status) + + # Draw straight lines + xpoints = sf.surv_times.repeat(2)[1:] + ypoints = sf.surv_prob.repeat(2)[:-1] + handle = ax.plot(xpoints, ypoints)[0] + + if transform_x: + xpoints = transform_x(xpoints) + if transform_y: + ypoints = transform_y(ypoints) + + if ci: + zstar = -stats.norm.ppf(config.alpha/2) + + # Get confidence intervals + ci0 = sf.surv_prob - zstar * sf.surv_prob_se + ci1 = sf.surv_prob + zstar * sf.surv_prob_se + + # Plot confidence intervals + ypoints0 = ci0.repeat(2)[:-1] + ypoints1 = ci1.repeat(2)[:-1] + 
+ if transform_y: + ypoints0 = transform_y(ypoints0) + ypoints1 = transform_y(ypoints1) + + ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_') + + return handle + +def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'): + """ + Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations + + The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta. + + For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval. + + Uses the Python *lifelines* and *matplotlib* libraries. + + :param df: Data to generate plot for + :type df: DataFrame + :param time_left: Column in *df* for the time to event, left interval endpoint (numeric) + :type time_left: str + :param time_right: Column in *df* for the time to event, right interval endpoint (numeric) + :type time_right: str + :param by: Column in *df* to stratify by (categorical) + :type by: str + :param transform_x: Function to transform x axis by + :type transform_x: callable + :param transform_y: Function to transform y axis by + :type transform_y: callable + :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`) + :type nan_policy: str + + :rtype: (Figure, Axes) + """ + + import matplotlib.pyplot as plt + + # Check for/clean NaNs + if by: + df = check_nan(df[[time_left, time_right, by]], nan_policy) + else: + df = check_nan(df[[time_left, time_right]], nan_policy) + + fig, ax = plt.subplots() + + if by is not None: + # Group by independent variable + groups = df.groupby(by) + + for group in groups.groups: + subset = groups.get_group(group) + handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y) + 
handle.set_label('{} = {}'.format(by, group)) + else: + # No grouping + plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y) + + ax.set_xlabel('Analysis time') + ax.set_ylabel('Survival probability') + ax.set_xlim(left=0) + ax.set_ylim(0, 1) + ax.legend() + + return fig, ax + +def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None): + import lifelines + + EPSILON = 1e-10 + + # TODO: Support left == right => failure was exactly observed + + followup_left = time_left + EPSILON # Add epsilon to make interval half-open + followup_right = time_right + + # Estimate the survival function + sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) + + # Draw straight lines + xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1] + med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2 + ypoints = med.to_numpy().repeat(2)[1:] + + if transform_x: + xpoints = transform_x(xpoints) + if transform_y: + ypoints = transform_y(ypoints) + + handle = ax.plot(xpoints, ypoints)[0] + + return handle + +def survtime_to_numeric(df, time): + """ + Convert pandas timedelta dtype to float64, auto-detecting the best time unit to display + + :param df: Data to check for pandas timedelta dtype + :type df: DataFrame + :param time: Column to check for pandas timedelta dtype + :type time: str + + :return: (*df*, *time_units*) + + * **df** (*DataFrame*) – Data with pandas timedelta dtypes converted, which is *not* copied + * **time_units** (*str*) – Human-readable description of the time unit, or *None* if not converted + """ + if df[time].dtype == '