From 6206723713cef5da7eff0326beb626bbc1fc4306 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 16 Oct 2022 02:31:08 +1100 Subject: [PATCH] Add model_kwargs, fit_kwargs and common arguments to yli.regress --- yli/regress.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/yli/regress.py b/yli/regress.py index 6bf0f61..75f8abd 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -342,6 +342,9 @@ class CategoricalTerm: def regress( model_class, df, dep, formula, *, nan_policy='warn', + model_kwargs=None, fit_kwargs=None, + family=None, # common model_kwargs + cov_type=None, maxiter=None, start_params=None, # common fit_kwargs bool_baselevels=False, exp=None ): """ @@ -351,9 +354,25 @@ def regress( exp: Report exponentiated parameters rather than raw parameters """ + # Populate model_kwargs + if model_kwargs is None: + model_kwargs = {} + if family is not None: + model_kwargs['family'] = family + + # Populate fit_kwargs + if fit_kwargs is None: + fit_kwargs = {} + if cov_type is not None: + fit_kwargs['cov_type'] = cov_type + if maxiter is not None: + fit_kwargs['maxiter'] = maxiter + if start_params is not None: + fit_kwargs['start_params'] = start_params + # Autodetect whether to exponentiate if exp is None: - if model_class is sm.Logit or model_class is PenalisedLogit: + if model_class in (sm.Logit, sm.Poisson, PenalisedLogit): exp = True else: exp = False @@ -372,8 +391,8 @@ def regress( df[col] = df[col].astype('float64') # Fit model - model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df) - result = model.fit() + model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs) + result = model.fit(**fit_kwargs) if isinstance(result, RegressionResult): # Already processed!