Add model_kwargs, fit_kwargs and common arguments to yli.regress
This commit is contained in:
parent
f407c5a44f
commit
6206723713
@ -342,6 +342,9 @@ class CategoricalTerm:
|
||||
def regress(
|
||||
model_class, df, dep, formula, *,
|
||||
nan_policy='warn',
|
||||
model_kwargs=None, fit_kwargs=None,
|
||||
family=None, # common model_kwargs
|
||||
cov_type=None, maxiter=None, start_params=None, # common fit_kwargs
|
||||
bool_baselevels=False, exp=None
|
||||
):
|
||||
"""
|
||||
@ -351,9 +354,25 @@ def regress(
|
||||
exp: Report exponentiated parameters rather than raw parameters
|
||||
"""
|
||||
|
||||
# Populate model_kwargs
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
if family is not None:
|
||||
model_kwargs['family'] = family
|
||||
|
||||
# Populate fit_kwargs
|
||||
if fit_kwargs is None:
|
||||
fit_kwargs = {}
|
||||
if cov_type is not None:
|
||||
fit_kwargs['cov_type'] = cov_type
|
||||
if maxiter is not None:
|
||||
fit_kwargs['maxiter'] = maxiter
|
||||
if start_params is not None:
|
||||
fit_kwargs['start_params'] = start_params
|
||||
|
||||
# Autodetect whether to exponentiate
|
||||
if exp is None:
|
||||
if model_class is sm.Logit or model_class is PenalisedLogit:
|
||||
if model_class in (sm.Logit, sm.Poisson, PenalisedLogit):
|
||||
exp = True
|
||||
else:
|
||||
exp = False
|
||||
@ -372,8 +391,8 @@ def regress(
|
||||
df[col] = df[col].astype('float64')
|
||||
|
||||
# Fit model
|
||||
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df)
|
||||
result = model.fit()
|
||||
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs)
|
||||
result = model.fit(**fit_kwargs)
|
||||
|
||||
if isinstance(result, RegressionResult):
|
||||
# Already processed!
|
||||
|
Loading…
Reference in New Issue
Block a user