From 4b9537643aa284887f0780d95978c368a4a18f30 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 16 Jul 2023 16:19:10 +1000 Subject: [PATCH] Add maxiter, tolerance parameters to IntervalCensoredCox --- docs/regress.rst | 2 ++ yli/regress.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/regress.rst b/docs/regress.rst index 5795754..f92a999 100644 --- a/docs/regress.rst +++ b/docs/regress.rst @@ -14,6 +14,8 @@ Functions Regression models ---------------------------- +.. autoclass:: yli.IntervalCensoredCox + .. autoclass:: yli.Logit .. autoclass:: yli.OrdinalLogit diff --git a/yli/regress.py b/yli/regress.py index 6d8c725..c39cf2e 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -102,7 +102,7 @@ def regress( *, nan_policy='warn', exposure=None, - method=None, maxiter=None, start_params=None, + method=None, maxiter=None, start_params=None, tolerance=None, reduced=None, bool_baselevels=False, exp=None ): @@ -124,6 +124,7 @@ def regress( :param method: See statsmodels *model.fit* :param maxiter: See statsmodels *model.fit* :param start_params: See statsmodels *model.fit* + :param tolerance: See statsmodels *model.fit* :param reduced: See :meth:`yli.IntervalCensoredCox` :param bool_baselevels: Show reference categories for boolean independent variables even if reference category is *False* :type bool_baselevels: bool @@ -151,6 +152,8 @@ def regress( fit_kwargs['maxiter'] = maxiter if start_params is not None: fit_kwargs['start_params'] = start_params + if tolerance is not None: + fit_kwargs['tolerance'] = tolerance if reduced is not None: fit_kwargs['reduced'] = reduced @@ -693,7 +696,7 @@ class IntervalCensoredCox(RegressionModel): self.lambda_ = None @classmethod - def fit(cls, data_dep, data_ind, *, reduced=False): + def fit(cls, data_dep, data_ind, *, reduced=False, maxiter=None, tolerance=None): if len(data_dep.columns) != 2: raise ValueError('IntervalCensoredCox requires left and right times') @@ -711,6 +714,12 @@ class IntervalCensoredCox(RegressionModel): intcox_args = [config.hpstat_path, 'intcox', '-', '--output', 'json'] if reduced: intcox_args.append('--reduced') + if maxiter: + intcox_args.append('--max-iterations') + intcox_args.append(str(maxiter)) + if tolerance: + intcox_args.append('--param-tolerance') + intcox_args.append(str(tolerance)) # Export data to CSV csv_buf = io.StringIO()