From 8238383edb93cbd72c6aa8812bddfe33bff95fe9 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 22 Apr 2023 00:43:01 +1000 Subject: [PATCH] Fix up yli.turnbull Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function) Draw survival curve from time 0, survival 100% --- yli/survival.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/yli/survival.py b/yli/survival.py index 710caf9..9b98827 100644 --- a/yli/survival.py +++ b/yli/survival.py @@ -118,13 +118,13 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_ return handle -def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'): +def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'): """ Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta. - For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval. + By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval. Uses the Python *lifelines* and *matplotlib* libraries. @@ -136,6 +136,8 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ :type time_right: str :param by: Column in *df* to stratify by (categorical) :type by: str + :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric) + :type step_loc: float :param transform_x: Function to transform x axis by :type transform_x: callable :param transform_y: Function to transform y axis by @@ -162,11 +164,11 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ for group in groups.groups: subset = groups.get_group(group) - handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y) + handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y) handle.set_label('{} = {}'.format(by, group)) else: # No grouping - plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y) + plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y) ax.set_xlabel('Analysis time') ax.set_ylabel('Survival probability') @@ -176,8 +178,13 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ return fig, ax -def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None): - import lifelines +def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None): + xpoints, ypoints = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y) + handle = ax.plot(xpoints, ypoints)[0] + return handle + +def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None): + from lifelines.fitters.npmle import npmle EPSILON = 1e-10 @@ -187,21 +194,25 @@ def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transfor followup_right = time_right # Estimate the survival function - sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) + #sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) + + # Call lifelines.fitters.npmle.npmle directly so we can compute midpoints, etc. + sf_probs, turnbull_intervals = npmle(followup_left, followup_right) + + xpoints = [i.left*(1-step_loc) + i.right*step_loc for i in turnbull_intervals] + ypoints = 1 - np.cumsum(sf_probs) # Draw straight lines - xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1] - med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2 - ypoints = med.to_numpy().repeat(2)[1:] + # np.concatenate(...) to force starting drawing from time 0, survival 100% + xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:] + ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1] if transform_x: xpoints = transform_x(xpoints) if transform_y: ypoints = transform_y(ypoints) - handle = ax.plot(xpoints, ypoints)[0] - - return handle + return xpoints, ypoints def survtime_to_numeric(df, time): """