Fix up yli.turnbull

Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function)
Draw survival curve from time 0, survival 100%
This commit is contained in:
RunasSudo 2023-04-22 00:43:01 +10:00
parent 8899d1c968
commit 8238383edb
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 24 additions and 13 deletions

View File

@ -118,13 +118,13 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_
return handle return handle
def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'): def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'):
""" """
Generate a Turnbull estimator plot, which extends the KaplanMeier estimator to interval-censored observations Generate a Turnbull estimator plot, which extends the KaplanMeier estimator to interval-censored observations
The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta. The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.
For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval. By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.
Uses the Python *lifelines* and *matplotlib* libraries. Uses the Python *lifelines* and *matplotlib* libraries.
@ -136,6 +136,8 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
:type time_right: str :type time_right: str
:param by: Column in *df* to stratify by (categorical) :param by: Column in *df* to stratify by (categorical)
:type by: str :type by: str
:param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
:type step_loc: float
:param transform_x: Function to transform x axis by :param transform_x: Function to transform x axis by
:type transform_x: callable :type transform_x: callable
:param transform_y: Function to transform y axis by :param transform_y: Function to transform y axis by
@ -162,11 +164,11 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
for group in groups.groups: for group in groups.groups:
subset = groups.get_group(group) subset = groups.get_group(group)
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y) handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y)
handle.set_label('{} = {}'.format(by, group)) handle.set_label('{} = {}'.format(by, group))
else: else:
# No grouping # No grouping
plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y) plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y)
ax.set_xlabel('Analysis time') ax.set_xlabel('Analysis time')
ax.set_ylabel('Survival probability') ax.set_ylabel('Survival probability')
@ -176,8 +178,13 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
return fig, ax return fig, ax
def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None): def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
import lifelines xpoints, ypoints = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y)
handle = ax.plot(xpoints, ypoints)[0]
return handle
def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
from lifelines.fitters.npmle import npmle
EPSILON = 1e-10 EPSILON = 1e-10
@ -187,21 +194,25 @@ def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transfor
followup_right = time_right followup_right = time_right
# Estimate the survival function # Estimate the survival function
sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right) #sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
# Call lifelines.fitters.npmle.npmle directly so we can compute midpoints, etc.
sf_probs, turnbull_intervals = npmle(followup_left, followup_right)
xpoints = [i.left*(1-step_loc) + i.right*step_loc for i in turnbull_intervals]
ypoints = 1 - np.cumsum(sf_probs)
# Draw straight lines # Draw straight lines
xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1] # np.concatenate(...) to force starting drawing from time 0, survival 100%
med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2 xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
ypoints = med.to_numpy().repeat(2)[1:] ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]
if transform_x: if transform_x:
xpoints = transform_x(xpoints) xpoints = transform_x(xpoints)
if transform_y: if transform_y:
ypoints = transform_y(ypoints) ypoints = transform_y(ypoints)
handle = ax.plot(xpoints, ypoints)[0] return xpoints, ypoints
return handle
def survtime_to_numeric(df, time): def survtime_to_numeric(df, time):
""" """