Implement yli.kaplanmeier
This commit is contained in:
parent
642d0d4e4f
commit
e83aa88b19
@ -1,5 +1,5 @@
|
|||||||
# scipy-yli: Helpful SciPy utilities and recipes
|
# scipy-yli: Helpful SciPy utilities and recipes
|
||||||
# Copyright © 2022 Lee Yingtong Li (RunasSudo)
|
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Affero General Public License as published by
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
@ -22,6 +22,7 @@ from .graphs import init_fonts
|
|||||||
from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
|
from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
|
||||||
from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif
|
from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif
|
||||||
from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind
|
from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind
|
||||||
|
from .survival import kaplanmeier
|
||||||
from .utils import as_ordinal
|
from .utils import as_ordinal
|
||||||
|
|
||||||
def reload_me():
|
def reload_me():
|
||||||
|
104
yli/survival.py
Normal file
104
yli/survival.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# scipy-yli: Helpful SciPy utilities and recipes
|
||||||
|
# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from scipy import stats
|
||||||
|
import statsmodels.api as sm
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
from .utils import check_nan
|
||||||
|
|
||||||
|
def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
    """
    Plot Kaplan–Meier survival curves on a new Matplotlib figure.

    :param df: Data frame containing the columns named below
    :param time: Name of the column holding the observation time; a
        ``timedelta64[ns]`` column is converted to a numeric column in an
        automatically chosen unit, which is appended to the x-axis label
    :param status: Name of the column passed as the *status* argument to
        statsmodels' ``SurvfuncRight`` (via ``plot_survfunc``)
    :param by: Name of a column to group by, one curve per group (optional)
    :param ci: Whether to shade a ``1 - config.alpha`` confidence band
    :param nan_policy: Passed through to ``check_nan``
    :return: The Matplotlib Axes the curves were drawn on
    """

    import matplotlib.pyplot as plt

    # Check for/clean NaNs
    # Compare against None (as the plotting branch below does) so that a
    # falsy column label such as 0 is still kept as the grouping column
    columns = [time, status] if by is None else [time, status, by]
    df = check_nan(df[columns], nan_policy)

    time_units = None
    if df[time].dtype == '<m8[ns]':
        # Timedeltas: work in float seconds, then rescale to a readable unit
        df[time] = df[time].dt.total_seconds()
        divisor, time_units = _best_time_units(df[time].max())
        if divisor != 1:
            df[time] = df[time] / divisor

    _, ax = plt.subplots()

    if by is not None:
        # Group by independent variable
        groups = df.groupby(by)

        for group in groups.groups:
            subset = groups.get_group(group)
            handle = plot_survfunc(ax, subset[time], subset[status], ci)
            handle.set_label('{} = {}'.format(by, group))
    else:
        # No grouping
        plot_survfunc(ax, df[time], df[status], ci)

    if time_units:
        ax.set_xlabel('{} ({})'.format(time, time_units))
    else:
        ax.set_xlabel(time)
    ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
    ax.set_ylim(0, 1)

    # Only the grouped curves carry labels; an unconditional legend() would
    # have nothing to show in the ungrouped case
    if by is not None:
        ax.legend()

    return ax


def _best_time_units(max_seconds):
    """Return ``(divisor, name)`` for the largest time unit that *max_seconds* exceeds."""

    units = (
        ('years', 365.24*24*60*60),
        ('weeks', 7*24*60*60),
        ('days', 24*60*60),
        ('hours', 60*60),
        ('minutes', 60),
    )
    for name, seconds_per_unit in units:
        if max_seconds > seconds_per_unit:
            return seconds_per_unit, name

    # Includes the NaN case (every comparison above is False)
    return 1, 'seconds'
|
||||||
|
|
||||||
|
def plot_survfunc(ax, time, status, ci):
    """
    Draw one Kaplan–Meier step curve, optionally with a confidence band.

    :param ax: Matplotlib Axes to draw on
    :param time: Observation times, passed to statsmodels' ``SurvfuncRight``
    :param status: Status values, passed to statsmodels' ``SurvfuncRight``
    :param ci: Whether to shade a normal-approximation confidence band at
        level ``1 - config.alpha`` around the curve
    :return: The line artist for the survival curve, so the caller can label it
    """

    import numpy as np

    # Estimate the survival function
    sf = sm.SurvfuncRight(time, status)

    # Draw straight lines
    # Prepend (0, 1) so the curve starts at time 0 with 100% survival rather
    # than at the first estimated time; duplicating every point and
    # offsetting x against y by one turns the estimates into steps
    xpoints = np.concatenate([[0], sf.surv_times]).repeat(2)[1:]
    ypoints = np.concatenate([[1], sf.surv_prob]).repeat(2)[:-1]
    handle = ax.plot(xpoints, ypoints)[0]

    if ci:
        zstar = -stats.norm.ppf(config.alpha/2)

        # Get confidence intervals
        ci0 = sf.surv_prob - zstar * sf.surv_prob_se
        ci1 = sf.surv_prob + zstar * sf.surv_prob_se

        # Plot confidence intervals
        # Both bounds are pinned to 1 at time 0, matching the curve above
        ypoints0 = np.concatenate([[1], ci0]).repeat(2)[:-1]
        ypoints1 = np.concatenate([[1], ci1]).repeat(2)[:-1]

        # Underscore label keeps the band out of the legend
        ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')

    return handle
|
Loading…
Reference in New Issue
Block a user