diff --git a/yli/__init__.py b/yli/__init__.py
index 7bf7c24..7c88db0 100644
--- a/yli/__init__.py
+++ b/yli/__init__.py
@@ -1,5 +1,5 @@
# scipy-yli: Helpful SciPy utilities and recipes
-# Copyright © 2022 Lee Yingtong Li (RunasSudo)
+# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -22,6 +22,7 @@ from .graphs import init_fonts
from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif
from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind
+from .survival import kaplanmeier
from .utils import as_ordinal
def reload_me():
diff --git a/yli/survival.py b/yli/survival.py
new file mode 100644
index 0000000..4e05792
--- /dev/null
+++ b/yli/survival.py
@@ -0,0 +1,104 @@
+# scipy-yli: Helpful SciPy utilities and recipes
+# Copyright © 2022–2023 Lee Yingtong Li (RunasSudo)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from scipy import stats
+import statsmodels.api as sm
+
+from .config import config
+from .utils import check_nan
+
+def kaplanmeier(df, time, status, by=None, ci=True, nan_policy='warn'):
+ # TODO: Documentation
+
+ import matplotlib.pyplot as plt
+
+ # Check for/clean NaNs
+ if by:
+ df = check_nan(df[[time, status, by]], nan_policy)
+ else:
+ df = check_nan(df[[time, status]], nan_policy)
+
+ if df[time].dtype == ' 365.24*24*60*60:
+ df[time] = df[time] / (365.24*24*60*60)
+ time_units = 'years'
+ elif df[time].max() > 7*24*60*60 / 12:
+ df[time] = df[time] / (7*24*60*60)
+ time_units = 'weeks'
+ elif df[time].max() > 24*60*60:
+ df[time] = df[time] / (24*60*60)
+ time_units = 'days'
+ elif df[time].max() > 60*60:
+ df[time] = df[time] / (60*60)
+ time_units = 'hours'
+ elif df[time].max() > 60:
+ df[time] = df[time] / 60
+ time_units = 'minutes'
+ else:
+ time_units = 'seconds'
+ else:
+ time_units = None
+
+ fig, ax = plt.subplots()
+
+ if by is not None:
+ # Group by independent variable
+ groups = df.groupby(by)
+
+ for group in groups.groups:
+ subset = groups.get_group(group)
+ handle = plot_survfunc(ax, subset[time], subset[status], ci)
+ handle.set_label('{} = {}'.format(by, group))
+ else:
+ # No grouping
+ plot_survfunc(ax, df[time], df[status], ci)
+
+ if time_units:
+ ax.set_xlabel('{} ({})'.format(time, time_units))
+ else:
+ ax.set_xlabel(time)
+ ax.set_ylabel('Survival probability ({:.0%} CI)'.format(1-config.alpha) if ci else 'Survival probability')
+ ax.set_ylim(0, 1)
+ ax.legend()
+
+ return ax
+
+def plot_survfunc(ax, time, status, ci):
+ # Estimate the survival function
+ sf = sm.SurvfuncRight(time, status)
+
+ # Draw straight lines
+ xpoints = sf.surv_times.repeat(2)[1:]
+ ypoints = sf.surv_prob.repeat(2)[:-1]
+ handle = ax.plot(xpoints, ypoints)[0]
+
+ if ci:
+ zstar = -stats.norm.ppf(config.alpha/2)
+
+ # Get confidence intervals
+ ci0 = sf.surv_prob - zstar * sf.surv_prob_se
+ ci1 = sf.surv_prob + zstar * sf.surv_prob_se
+
+ # Plot confidence intervals
+ ypoints0 = ci0.repeat(2)[:-1]
+ ypoints1 = ci1.repeat(2)[:-1]
+
+ ax.fill_between(xpoints, ypoints0, ypoints1, alpha=0.3, label='_')
+
+ return handle