Add model_kwargs, fit_kwargs and common arguments to yli.regress

This commit is contained in:
RunasSudo 2022-10-16 02:31:08 +11:00
parent f407c5a44f
commit 6206723713
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 22 additions and 3 deletions

View File

@ -342,6 +342,9 @@ class CategoricalTerm:
def regress(
model_class, df, dep, formula, *,
nan_policy='warn',
model_kwargs=None, fit_kwargs=None,
family=None, # common model_kwargs
cov_type=None, maxiter=None, start_params=None, # common fit_kwargs
bool_baselevels=False, exp=None
):
"""
@ -351,9 +354,25 @@ def regress(
exp: Report exponentiated parameters rather than raw parameters
"""
# Populate model_kwargs
if model_kwargs is None:
model_kwargs = {}
if family is not None:
model_kwargs['family'] = family
# Populate fit_kwargs
if fit_kwargs is None:
fit_kwargs = {}
if cov_type is not None:
fit_kwargs['cov_type'] = cov_type
if maxiter is not None:
fit_kwargs['maxiter'] = maxiter
if start_params is not None:
fit_kwargs['start_params'] = start_params
# Autodetect whether to exponentiate
if exp is None:
if model_class is sm.Logit or model_class is PenalisedLogit:
if model_class in (sm.Logit, sm.Poisson, PenalisedLogit):
exp = True
else:
exp = False
@ -372,8 +391,8 @@ def regress(
df[col] = df[col].astype('float64')
# Fit model
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df)
result = model.fit()
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs)
result = model.fit(**fit_kwargs)
if isinstance(result, RegressionResult):
# Already processed!