408 lines
14 KiB
Python
408 lines
14 KiB
Python
# scipy-yli: Helpful SciPy utilities and recipes
|
|
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from scipy import stats
|
|
import statsmodels.api as sm
|
|
|
|
import io
|
|
import json
|
|
import subprocess
|
|
|
|
from .config import config
|
|
from .sig_tests import ChiSquaredResult
|
|
from .utils import Estimate, check_nan
|
|
|
|
def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn', fig=None, ax=None):
    """
    Generate a Kaplan–Meier plot

    Uses the Python *matplotlib* library.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param ci: Whether to plot confidence intervals around the survival function
    :type ci: bool
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str
    :param fig: Figure to return alongside *ax*; if omitted while *ax* is given, the figure containing *ax* is returned
    :type fig: Figure
    :param ax: Axes to draw on; if omitted, a new figure and axes are created
    :type ax: Axes

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    # Check for/clean NaNs, keeping only the columns needed for the plot
    # Compare against None (not truthiness) so that falsy column labels such as 0 are not dropped
    columns = [time, status] if by is None else [time, status, by]
    df = check_nan(df[columns], nan_policy)

    # Convert timedelta to numeric
    df, time_units = survtime_to_numeric(df, time)

    if ax is None:
        fig, ax = plt.subplots()
    elif fig is None:
        # Only an Axes was supplied: return the Figure it belongs to, so the result is always (Figure, Axes)
        fig = ax.figure

    if by is not None:
        # Group by independent variable and draw one labelled curve per group
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_kaplanmeier(ax, subset[time], subset[status], ci, transform_x, transform_y)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_kaplanmeier(ax, df[time], df[status], ci, transform_x, transform_y)

    if time_units:
        ax.set_xlabel('{} ({})'.format(time, time_units))
    else:
        ax.set_xlabel(time)
    ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    if by is not None:
        ax.legend()

    return fig, ax
|
|
|
|
def plot_survfunc_kaplanmeier(ax, time, status, ci, transform_x=None, transform_y=None):
    """
    Draw a single Kaplan–Meier survival curve, and optionally its confidence band, on *ax*

    :param ax: Axes to draw on
    :param time: Times to event
    :param status: Event status for each time
    :param ci: Whether to also shade the confidence band
    :param transform_x: Function to transform x coordinates by
    :param transform_y: Function to transform y coordinates by

    :return: The first line artist returned by ``ax.plot`` (the survival curve), so the caller can label it
    """
    x_coords, y_coords, y_lower, y_upper = calc_survfunc_kaplanmeier(time, status, ci, transform_x, transform_y)

    lines = ax.plot(x_coords, y_coords)
    curve = lines[0]

    if ci:
        # A label beginning with an underscore keeps the band out of the legend
        ax.fill_between(x_coords, y_lower, y_upper, alpha=0.3, label='_')

    return curve
|
|
|
|
def calc_survfunc_kaplanmeier(time, status, ci, transform_x=None, transform_y=None):
    """
    Estimate a Kaplan–Meier survival function and return the coordinates of its step plot

    The curve always starts at time 0 with survival 100%. Confidence limits, if requested, are the
    pointwise normal-approximation limits surv_prob ± z·se, with z taken from ``config.alpha``.

    :param time: Times to event
    :param status: Event status for each time
    :param ci: Whether to compute confidence limits
    :param transform_x: Function to transform x coordinates by
    :param transform_y: Function to transform y coordinates (and confidence limits) by

    :return: (*xpoints*, *ypoints*, *ypoints0*, *ypoints1*); the last two are the lower/upper limits, or *None* if *ci* is false
    """
    survfunc = sm.SurvfuncRight(time, status)

    def as_step_heights(values):
        # Prefix survival 100%, then duplicate each value and drop the final copy,
        # giving the heights that pair with the step x coordinates below
        return np.concatenate([[1], values]).repeat(2)[:-1]

    # Prefix time 0, duplicate each time and drop the first copy, so each step is a horizontal then vertical segment
    xpoints = np.concatenate([[0], survfunc.surv_times]).repeat(2)[1:]
    ypoints = as_step_heights(survfunc.surv_prob)

    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    if not ci:
        return xpoints, ypoints, None, None

    zstar = -stats.norm.ppf(config.alpha/2)
    margin = zstar * survfunc.surv_prob_se

    ypoints0 = as_step_heights(survfunc.surv_prob - margin)
    ypoints1 = as_step_heights(survfunc.surv_prob + margin)

    if transform_y:
        ypoints0 = transform_y(ypoints0)
        ypoints1 = transform_y(ypoints1)

    return xpoints, ypoints, ypoints0, ypoints1
|
|
|
|
def turnbull(df, time_left, time_right, by=None, *, ci=True, step_loc=0.5, maxiter=None, ll_tolerance=None, se_method=None, zero_tolerance=None, ci_precision=None, transform_x=None, transform_y=None, nan_policy='warn', fig=None, ax=None):
    """
    Generate a Turnbull estimator plot, which extends the Kaplan–Meier estimator to interval-censored observations

    The intervals are assumed to be half-open intervals, (*left*, *right*]. *right* == *np.inf* implies the event was right-censored.

    By default, the survival function is drawn as a step function at the midpoint of each Turnbull interval.

    Uses the hpstat *turnbull* command.

    :param df: Data to generate plot for
    :type df: DataFrame
    :param time_left: Column in *df* for the time to event, left interval endpoint (numeric or timedelta)
    :type time_left: str
    :param time_right: Column in *df* for the time to event, right interval endpoint (numeric or timedelta)
    :type time_right: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param ci: Whether to plot confidence intervals around the survival function
    :type ci: bool
    :param step_loc: Proportion along the length of each Turnbull interval to step down the survival function, e.g. 0 for left bound, 1 for right bound, 0.5 for interval midpoint
    :type step_loc: float
    :param maxiter: Maximum number of iterations to attempt
    :type maxiter: int
    :param ll_tolerance: Terminate algorithm when the absolute change in log-likelihood is less than this tolerance
    :type ll_tolerance: float
    :param se_method: Method for computing standard error of survival probabilities (see hpstat *turnbull* documentation)
    :type se_method: str
    :param zero_tolerance: Threshold for dropping failure probability when *se_method* is "oim-drop-zeros"
    :type zero_tolerance: float
    :param ci_precision: Desired precision of confidence limits when *se_method* is "likelihood-ratio"
    :type ci_precision: float
    :param transform_x: Function to transform x axis by
    :type transform_x: callable
    :param transform_y: Function to transform y axis by
    :type transform_y: callable
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str
    :param fig: Figure to return alongside *ax*; if omitted while *ax* is given, the figure containing *ax* is returned
    :type fig: Figure
    :param ax: Axes to draw on; if omitted, a new figure and axes are created
    :type ax: Axes

    :rtype: (Figure, Axes)
    """

    import matplotlib.pyplot as plt

    # Check for/clean NaNs, keeping only the columns needed for the plot
    # Compare against None (not truthiness) so that falsy column labels such as 0 are not dropped
    columns = [time_left, time_right] if by is None else [time_left, time_right, by]
    df = check_nan(df[columns], nan_policy)

    # Convert timedelta to numeric
    df, time_units = survtime_to_numeric(df, time_left, time_right)

    if ax is None:
        fig, ax = plt.subplots()
    elif fig is None:
        # Only an Axes was supplied: return the Figure it belongs to, so the result is always (Figure, Axes)
        fig = ax.figure

    # Estimator/plotting options, forwarded unchanged for every curve
    estimator_options = dict(
        ci=ci, step_loc=step_loc, maxiter=maxiter, ll_tolerance=ll_tolerance, se_method=se_method,
        zero_tolerance=zero_tolerance, ci_precision=ci_precision, transform_x=transform_x, transform_y=transform_y,
    )

    if by is not None:
        # Group by independent variable and draw one labelled curve per group
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc_turnbull(ax, subset[time_left], subset[time_right], **estimator_options)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc_turnbull(ax, df[time_left], df[time_right], **estimator_options)

    if time_units:
        ax.set_xlabel('{} + {} ({})'.format(time_left, time_right, time_units))
    else:
        ax.set_xlabel('{} + {}'.format(time_left, time_right))
    ax.set_ylabel('Survival probability')
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1)

    if by is not None:
        ax.legend()

    return fig, ax
|
|
|
|
def plot_survfunc_turnbull(ax, time_left, time_right, *, ci=True, step_loc=0.5, maxiter=None, ll_tolerance=None, se_method=None, zero_tolerance=None, ci_precision=None, transform_x=None, transform_y=None):
    """
    Draw a single Turnbull survival curve, and optionally its confidence band, on *ax*

    All keyword arguments are passed through to :func:`calc_survfunc_turnbull`.

    :param ax: Axes to draw on
    :param time_left: Left interval endpoints
    :param time_right: Right interval endpoints

    :return: The first line artist returned by ``ax.plot`` (the survival curve), so the caller can label it
    """
    estimator_options = dict(
        ci=ci, step_loc=step_loc, maxiter=maxiter, ll_tolerance=ll_tolerance, se_method=se_method,
        zero_tolerance=zero_tolerance, ci_precision=ci_precision, transform_x=transform_x, transform_y=transform_y,
    )
    x_coords, y_coords, y_lower, y_upper = calc_survfunc_turnbull(time_left, time_right, **estimator_options)

    lines = ax.plot(x_coords, y_coords)
    curve = lines[0]

    if ci:
        # A label beginning with an underscore keeps the band out of the legend
        ax.fill_between(x_coords, y_lower, y_upper, alpha=0.3, label='_')

    return curve
|
|
|
|
def calc_survfunc_turnbull(time_left, time_right, *, ci=True, step_loc=0.5, maxiter=None, ll_tolerance=None, se_method=None, zero_tolerance=None, ci_precision=None, transform_x=None, transform_y=None):
    """
    Estimate a Turnbull survival function with the hpstat *turnbull* command and return the coordinates of its step plot

    :param time_left: Left interval endpoints
    :param time_right: Right interval endpoints (*np.inf* for right-censored observations)
    :param ci: Whether to compute confidence limits
    :param step_loc: Proportion along each Turnbull interval at which to step down the survival function
    :param maxiter: Passed to hpstat as ``--max-iterations`` unless *None*
    :param ll_tolerance: Passed to hpstat as ``--ll-tolerance`` unless *None*
    :param se_method: Passed to hpstat as ``--se-method``; if not given and *ci* is false, ``none`` is passed
    :param zero_tolerance: Passed to hpstat as ``--zero-tolerance`` unless *None*
    :param ci_precision: Passed to hpstat as ``--ci-precision`` unless *None*
    :param transform_x: Function to transform x coordinates by
    :param transform_y: Function to transform y coordinates (and confidence limits) by

    :return: (*xpoints*, *ypoints*, *ypoints0*, *ypoints1*); the last two are the lower/upper limits, or *None* if *ci* is false
    :raises subprocess.CalledProcessError: If hpstat exits with a non-zero status
    """

    # Prepare arguments; options left as None fall back to hpstat's own defaults
    # (compare numeric options against None so that an explicit 0 is still passed through)
    hpstat_args = [config.hpstat_path, 'turnbull', '-', '--output', 'json']
    if maxiter is not None:
        hpstat_args.extend(['--max-iterations', str(maxiter)])
    if ll_tolerance is not None:
        hpstat_args.extend(['--ll-tolerance', str(ll_tolerance)])
    if se_method:
        hpstat_args.extend(['--se-method', se_method])
    elif not ci:
        # No CI wanted, so ask hpstat not to compute standard errors
        hpstat_args.extend(['--se-method', 'none'])
    if zero_tolerance is not None:
        hpstat_args.extend(['--zero-tolerance', str(zero_tolerance)])
    if ci_precision is not None:
        hpstat_args.extend(['--ci-precision', str(ci_precision)])

    # Export data to CSV; it is piped to hpstat's stdin (the '-' argument above)
    csv_buf = io.StringIO()
    pd.DataFrame({'LeftTime': time_left, 'RightTime': time_right}).to_csv(csv_buf, index=False)
    csv_str = csv_buf.getvalue()

    # Run hpstat binary; stdout carries the JSON result, stderr is inherited rather than captured
    proc = subprocess.run(hpstat_args, input=csv_str, stdout=subprocess.PIPE, stderr=None, encoding='utf-8', check=True)
    raw_result = json.loads(proc.stdout)

    # Under IPython/Jupyter, clear the cell output (wait=True defers clearing until new output arrives)
    # NOTE(review): presumably to remove progress output shown while hpstat ran; skipped when IPython is not installed
    try:
        from IPython.display import clear_output
    except ImportError:
        pass
    else:
        clear_output(wait=True)

    failure_intervals = raw_result['failure_intervals']
    survival_prob = np.array(raw_result['survival_prob'])

    # A falsy right endpoint on the final interval marks right-censoring
    # (presumably hpstat emits null for an infinite bound — confirm against hpstat output)
    whole_curve = bool(failure_intervals[-1][1])

    # Step down at the chosen proportion along each interval with a finite right endpoint
    xpoints = [interval[0]*(1-step_loc) + interval[1]*step_loc for interval in failure_intervals if interval[1]]
    ypoints = survival_prob
    if whole_curve:
        # No right-censored observations - we can draw the whole survival curve
        ypoints = np.concatenate([ypoints, [0]])

    # Draw straight lines
    # np.concatenate(...) to force starting drawing from time 0, survival 100%
    xpoints = np.concatenate([[0], xpoints]).repeat(2)[1:]
    ypoints = np.concatenate([[1], ypoints]).repeat(2)[:-1]

    if transform_x:
        xpoints = transform_x(xpoints)
    if transform_y:
        ypoints = transform_y(ypoints)

    if not ci:
        return xpoints, ypoints, None, None

    # Get confidence intervals: normal approximation from standard errors if hpstat gave them,
    # otherwise the confidence limits hpstat reported directly
    if raw_result['survival_prob_se']:
        zstar = -stats.norm.ppf(config.alpha/2)
        survival_prob_se = np.array(raw_result['survival_prob_se'])
        ci0 = survival_prob - zstar * survival_prob_se
        ci1 = survival_prob + zstar * survival_prob_se
    else:
        survival_prob_ci = np.array(raw_result['survival_prob_ci'])
        ci0 = survival_prob_ci.T[0]
        ci1 = survival_prob_ci.T[1]

    if whole_curve:
        # No right-censored observations - we can draw the whole survival curve
        ci0 = np.concatenate([ci0, [0]])
        ci1 = np.concatenate([ci1, [0]])

    # Plot confidence intervals
    ypoints0 = np.concatenate([[1], ci0]).repeat(2)[:-1]
    ypoints1 = np.concatenate([[1], ci1]).repeat(2)[:-1]

    if transform_y:
        ypoints0 = transform_y(ypoints0)
        ypoints1 = transform_y(ypoints1)

    return xpoints, ypoints, ypoints0, ypoints1
|
|
|
|
def survtime_to_numeric(df, time, time2=None):
    """
    Convert pandas timedelta dtype to float64, auto-detecting the best time unit to display

    Only columns of timedelta dtype are converted; other columns are left unchanged.
    The unit is chosen from the largest converted value: years if over 1 year, weeks if over
    about 1 month (1/12 of a year), days if over 1 day, hours if over 1 hour, minutes if over
    1 minute, otherwise seconds.

    :param df: Data to check for pandas timedelta dtype
    :type df: DataFrame
    :param time: Column to check for pandas timedelta dtype
    :type time: str
    :param time2: Second column, if any, to check for pandas timedelta dtype
    :type time2: str

    :return: (*df*, *time_units*)

        * **df** (*DataFrame*) – Data with pandas timedelta dtypes converted, which is *not* copied
        * **time_units** (*str*) – Human-readable description of the time unit, or *None* if not converted
    """

    seconds_per_day = 24*60*60
    seconds_per_year = 365.24*seconds_per_day

    # (threshold, divider, units) in decreasing order; the first threshold exceeded by the largest time wins.
    # Weeks start at one average month, so that the days branch is reachable for spans up to about a month.
    unit_table = [
        (seconds_per_year, seconds_per_year, 'years'),
        (seconds_per_year / 12, 7*seconds_per_day, 'weeks'),
        (seconds_per_day, seconds_per_day, 'days'),
        (60*60, 60*60, 'hours'),
        (60, 60, 'minutes'),
    ]

    # Convert each timedelta column to seconds, remembering which were converted so only those are rescaled
    converted = []
    for column in (time, time2):
        if column is None or column in converted:
            continue
        # dtype.kind 'm' matches timedelta64 of any resolution, not only nanoseconds
        if df[column].dtype.kind == 'm':
            df[column] = df[column].dt.total_seconds()
            converted.append(column)

    if not converted:
        return df, None

    # Auto-detect best time units
    max_time = max(df[column].max() for column in converted)

    time_divider, time_units = 1, 'seconds'
    for threshold, divider, units in unit_table:
        if max_time > threshold:
            time_divider, time_units = divider, units
            break

    for column in converted:
        df[column] /= time_divider

    return df, time_units
|
|
|
|
def logrank(df, time, status, by, nan_policy='warn'):
    """
    Perform the log-rank test for equality of survival functions

    With *k* distinct groups in *by*, the test statistic has *k* − 1 degrees of freedom.

    :param df: Data to perform the test on
    :type df: DataFrame
    :param time: Column in *df* for the time to event (numeric or timedelta)
    :type time: str
    :param status: Column in *df* for the status variable (True/False or 1/0)
    :type status: str
    :param by: Column in *df* to stratify by (categorical)
    :type by: str
    :param nan_policy: How to handle *nan* values (see :ref:`nan-handling`)
    :type nan_policy: str

    :rtype: :class:`yli.sig_tests.ChiSquaredResult`
    """

    # TODO: Example

    # Check for/clean NaNs
    df = check_nan(df[[time, status, by]], nan_policy)

    # The log-rank test does not depend on the time unit, so timedeltas are simply converted to seconds
    # (dtype.kind 'm' matches timedelta64 of any resolution)
    if df[time].dtype.kind == 'm':
        df[time] = df[time].dt.total_seconds()

    statistic, pvalue = sm.duration.survdiff(df[time], df[status], df[by])

    # Degrees of freedom = number of groups compared − 1 (not always 1, e.g. for 3+ groups)
    dof = df[by].nunique() - 1

    return ChiSquaredResult(statistic=statistic, dof=dof, pvalue=pvalue)
|