Fix up yli.turnbull

Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function)
Draw survival curve from time 0, survival 100%
This commit is contained in:
RunasSudo 2023-04-22 00:43:01 +10:00
parent 8899d1c968
commit 8238383edb
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 24 additions and 13 deletions

View File

@ -118,13 +118,13 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_
return handle
def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'):
def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'):
"""
Generate a Turnbull estimator plot, which extends the KaplanMeier estimator to interval-censored observations
The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.
For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval.
By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.
Uses the Python *lifelines* and *matplotlib* libraries.
@ -136,6 +136,8 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
:type time_right: str
:param by: Column in *df* to stratify by (categorical)
:type by: str
:param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
:type step_loc: float
:param transform_x: Function to transform x axis by
:type transform_x: callable
:param transform_y: Function to transform y axis by
@ -162,11 +164,11 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
for group in groups.groups:
subset = groups.get_group(group)
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y)
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y)
handle.set_label('{} = {}'.format(by, group))
else:
# No grouping
plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y)
plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y)
ax.set_xlabel('Analysis time')
ax.set_ylabel('Survival probability')
@ -176,8 +178,13 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
return fig, ax
def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None):
import lifelines
def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
xpoints, ypoints = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y)
handle = ax.plot(xpoints, ypoints)[0]
return handle
def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
from lifelines.fitters.npmle import npmle
EPSILON = 1e-10
@ -187,21 +194,25 @@ def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transfor
followup_right = time_right
# Estimate the survival function
sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
#sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
# Call lifelines.fitters.npmle.npmle directly so we can compute midpoints, etc.
sf_probs, turnbull_intervals = npmle(followup_left, followup_right)
xpoints = [i.left*(1-step_loc) + i.right*step_loc for i in turnbull_intervals]
ypoints = 1 - np.cumsum(sf_probs)
# Draw straight lines
xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1]
med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2
ypoints = med.to_numpy().repeat(2)[1:]
# np.concatenate(...) to force starting drawing from time 0, survival 100%
xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]
if transform_x:
xpoints = transform_x(xpoints)
if transform_y:
ypoints = transform_y(ypoints)
handle = ax.plot(xpoints, ypoints)[0]
return handle
return xpoints, ypoints
def survtime_to_numeric(df, time):
"""