diff --git a/yli/survival.py b/yli/survival.py index 710caf9..9b98827 100644 --- a/yli/survival.py +++ b/yli/survival.py @@ -118,13 +118,13 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_ return handle -def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'): +def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'): """ Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta. - For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval. + By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval. Uses the Python *lifelines* and *matplotlib* libraries. @@ -136,6 +136,8 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ :type time_right: str :param by: Column in *df* to stratify by (categorical) :type by: str + :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric) + :type step_loc: float :param transform_x: Function to transform x axis by :type transform_x: callable :param transform_y: Function to transform y axis by @@ -162,11 +164,11 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ for group in groups.groups: subset = groups.get_group(group) - handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y) + handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y) handle.set_label('{} = {}'.format(by, group)) else: # No grouping - plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y) + plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y) ax.set_xlabel('Analysis time') ax.set_ylabel('Survival probability') @@ -176,8 +178,13 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_ return fig, ax -def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None): - import lifelines +def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None): + xpoints, ypoints = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y) + handle = ax.plot(xpoints, ypoints)[0] + return handle + +def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None): + from lifelines.fitters.npmle import npmle EPSILON = 1e-10 @@ -187,21 +194,25 @@ def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transfor followup_right = time_right # Estimate the survival function - sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) + #sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) + + # Call lifelines.fitters.npmle.npmle directly so we can compute midpoints, etc. + sf_probs, turnbull_intervals = npmle(followup_left, followup_right) + + xpoints = [i.left*(1-step_loc) + i.right*step_loc for i in turnbull_intervals] + ypoints = 1 - np.cumsum(sf_probs) # Draw straight lines - xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1] - med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2 - ypoints = med.to_numpy().repeat(2)[1:] + # np.concatenate(...) to force starting drawing from time 0, survival 100% + xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:] + ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1] if transform_x: xpoints = transform_x(xpoints) if transform_y: ypoints = transform_y(ypoints) - handle = ax.plot(xpoints, ypoints)[0] - - return handle + return xpoints, ypoints def survtime_to_numeric(df, time): """