Add maxiter, tolerance parameters to IntervalCensoredCox

This commit is contained in:
RunasSudo 2023-07-16 16:19:10 +10:00
parent 71b714ab7d
commit 4b9537643a
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
2 changed files with 13 additions and 2 deletions

View File

@ -14,6 +14,8 @@ Functions
Regression models
----------------------------
.. autoclass:: yli.IntervalCensoredCox
.. autoclass:: yli.Logit
.. autoclass:: yli.OrdinalLogit

View File

@ -102,7 +102,7 @@ def regress(
*,
nan_policy='warn',
exposure=None,
method=None, maxiter=None, start_params=None,
method=None, maxiter=None, start_params=None, tolerance=None,
reduced=None,
bool_baselevels=False, exp=None
):
@ -124,6 +124,7 @@ def regress(
:param method: See statsmodels *model.fit*
:param maxiter: See statsmodels *model.fit*
:param start_params: See statsmodels *model.fit*
:param tolerance: See statsmodels *model.fit*
:param reduced: See :meth:`yli.IntervalCensoredCox`
:param bool_baselevels: Show reference categories for boolean independent variables even if reference category is *False*
:type bool_baselevels: bool
@ -151,6 +152,8 @@ def regress(
fit_kwargs['maxiter'] = maxiter
if start_params is not None:
fit_kwargs['start_params'] = start_params
if tolerance is not None:
fit_kwargs['tolerance'] = tolerance
if reduced is not None:
fit_kwargs['reduced'] = reduced
@ -693,7 +696,7 @@ class IntervalCensoredCox(RegressionModel):
self.lambda_ = None
@classmethod
def fit(cls, data_dep, data_ind, *, reduced=False):
def fit(cls, data_dep, data_ind, *, reduced=False, maxiter=None, tolerance=None):
if len(data_dep.columns) != 2:
raise ValueError('IntervalCensoredCox requires left and right times')
@ -711,6 +714,12 @@ class IntervalCensoredCox(RegressionModel):
intcox_args = [config.hpstat_path, 'intcox', '-', '--output', 'json']
if reduced:
intcox_args.append('--reduced')
if maxiter:
intcox_args.append('--max-iterations')
intcox_args.append(str(maxiter))
if tolerance:
intcox_args.append('--param-tolerance')
intcox_args.append(str(tolerance))
# Export data to CSV
csv_buf = io.StringIO()