Fix up yli.turnbull
Draw step function at configurable point on each Turnbull interval (previous documentation did not correctly describe the behaviour of the function) Draw survival curve from time 0, survival 100%
This commit is contained in:
parent
8899d1c968
commit
8238383edb
@ -118,13 +118,13 @@ def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_
|
|||||||
|
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_y=None, nan_policy='warn'):
|
def turnbull(df, time_left, time_right, by=None, *, step_loc=0.5, transform_x=None, transform_y=None, nan_policy='warn'):
|
||||||
"""
|
"""
|
||||||
Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations
|
Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations
|
||||||
|
|
||||||
The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.
|
The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored. Unlike :func:`yli.kaplanmeier`, times must be given as numeric dtypes and not as pandas timedelta.
|
||||||
|
|
||||||
For ease of interpretation, the survival function is drawn as a step function at the midpoint of the estimate on each interval.
|
By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.
|
||||||
|
|
||||||
Uses the Python *lifelines* and *matplotlib* libraries.
|
Uses the Python *lifelines* and *matplotlib* libraries.
|
||||||
|
|
||||||
@ -136,6 +136,8 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
|
|||||||
:type time_right: str
|
:type time_right: str
|
||||||
:param by: Column in *df* to stratify by (categorical)
|
:param by: Column in *df* to stratify by (categorical)
|
||||||
:type by: str
|
:type by: str
|
||||||
|
:param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint (numeric)
|
||||||
|
:type step_loc: float
|
||||||
:param transform_x: Function to transform x axis by
|
:param transform_x: Function to transform x axis by
|
||||||
:type transform_x: callable
|
:type transform_x: callable
|
||||||
:param transform_y: Function to transform y axis by
|
:param transform_y: Function to transform y axis by
|
||||||
@ -162,11 +164,11 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
|
|||||||
|
|
||||||
for group in groups.groups:
|
for group in groups.groups:
|
||||||
subset = groups.get_group(group)
|
subset = groups.get_group(group)
|
||||||
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], transform_x, transform_y)
|
handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], step_loc, transform_x, transform_y)
|
||||||
handle.set_label('{} = {}'.format(by, group))
|
handle.set_label('{} = {}'.format(by, group))
|
||||||
else:
|
else:
|
||||||
# No grouping
|
# No grouping
|
||||||
plot_survfunc_turnbull(ax, df[time_left], df[time_right], transform_x, transform_y)
|
plot_survfunc_turnbull(ax, df[time_left], df[time_right], step_loc, transform_x, transform_y)
|
||||||
|
|
||||||
ax.set_xlabel('Analysis time')
|
ax.set_xlabel('Analysis time')
|
||||||
ax.set_ylabel('Survival probability')
|
ax.set_ylabel('Survival probability')
|
||||||
@ -176,8 +178,13 @@ def turnbull(df, time_left, time_right, by=None, *, transform_x=None, transform_
|
|||||||
|
|
||||||
return fig, ax
|
return fig, ax
|
||||||
|
|
||||||
def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transform_y=None):
|
def plot_survfunc_turnbull(ax, time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
|
||||||
import lifelines
|
xpoints, ypoints = calc_survfunc_turnbull(time_left, time_right, step_loc, transform_x, transform_y)
|
||||||
|
handle = ax.plot(xpoints, ypoints)[0]
|
||||||
|
return handle
|
||||||
|
|
||||||
|
def calc_survfunc_turnbull(time_left, time_right, step_loc=0.5, transform_x=None, transform_y=None):
|
||||||
|
from lifelines.fitters.npmle import npmle
|
||||||
|
|
||||||
EPSILON = 1e-10
|
EPSILON = 1e-10
|
||||||
|
|
||||||
@ -187,21 +194,25 @@ def plot_survfunc_turnbull(ax, time_left, time_right, transform_x=None, transfor
|
|||||||
followup_right = time_right
|
followup_right = time_right
|
||||||
|
|
||||||
# Estimate the survival function
|
# Estimate the survival function
|
||||||
sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
|
#sf = lifelines.KaplanMeierFitter().fit_interval_censoring(followup_left, followup_right)
|
||||||
|
|
||||||
|
# Call lifelines.fitters.npmle.npmle directly so we can compute midpoints, etc.
|
||||||
|
sf_probs, turnbull_intervals = npmle(followup_left, followup_right)
|
||||||
|
|
||||||
|
xpoints = [i.left*(1-step_loc) + i.right*step_loc for i in turnbull_intervals]
|
||||||
|
ypoints = 1 - np.cumsum(sf_probs)
|
||||||
|
|
||||||
# Draw straight lines
|
# Draw straight lines
|
||||||
xpoints = sf.survival_function_.index.to_numpy().repeat(2)[:-1]
|
# np.concatenate(...) to force starting drawing from time 0, survival 100%
|
||||||
med = (sf.survival_function_['NPMLE_estimate_upper'] + sf.survival_function_['NPMLE_estimate_lower']) / 2
|
xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
|
||||||
ypoints = med.to_numpy().repeat(2)[1:]
|
ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]
|
||||||
|
|
||||||
if transform_x:
|
if transform_x:
|
||||||
xpoints = transform_x(xpoints)
|
xpoints = transform_x(xpoints)
|
||||||
if transform_y:
|
if transform_y:
|
||||||
ypoints = transform_y(ypoints)
|
ypoints = transform_y(ypoints)
|
||||||
|
|
||||||
handle = ax.plot(xpoints, ypoints)[0]
|
return xpoints, ypoints
|
||||||
|
|
||||||
return handle
|
|
||||||
|
|
||||||
def survtime_to_numeric(df, time):
|
def survtime_to_numeric(df, time):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user