diff --git a/yli/survival.py b/yli/survival.py index aa36fb3..87bc414 100644 --- a/yli/survival.py +++ b/yli/survival.py @@ -22,7 +22,7 @@ from .config import config from .sig_tests import ChiSquaredResult from .utils import Estimate, check_nan -def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn'): +def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transform_y=None, nan_policy='warn', fig=None, ax=None): """ Generate a Kaplan–Meier plot @@ -59,7 +59,8 @@ def kaplanmeier(df, time, status, by=None, *, ci=True, transform_x=None, transfo # Covert timedelta to numeric df, time_units = survtime_to_numeric(df, time) - fig, ax = plt.subplots() + if ax is None: + fig, ax = plt.subplots() if by is not None: # Group by independent variable