diff --git a/yli/__init__.py b/yli/__init__.py index c49b56e..1afea8d 100644 --- a/yli/__init__.py +++ b/yli/__init__.py @@ -18,7 +18,7 @@ from .bayes_factors import bayesfactor_afbf from .config import config from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted -from .regress import PenalisedLogit, regress, vif +from .regress import PenalisedLogit, logit_then_regress, regress, vif from .sig_tests import anova_oneway, chi2, mannwhitney, pearsonr, ttest_ind def reload_me(): diff --git a/yli/regress.py b/yli/regress.py index eeb7cc6..ba951ab 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -462,6 +462,21 @@ def regress( exp ) +def logit_then_regress(model_class, df, dep, formula, *, nan_policy='warn', **kwargs): + """Perform logistic regression, then use parameters as start parameters for desired regression""" + + # Check for/clean NaNs + # Do this once here so we only get 1 warning + df = df[[dep] + cols_for_formula(formula, df)] + df = check_nan(df, nan_policy) + + # Perform logistic regression + logit_result = regress(sm.Logit, df, dep, formula, **kwargs) + logit_params = logit_result.raw_result.params + + # Perform desired regression + return regress(model_class, df, dep, formula, start_params=logit_params, **kwargs) + # ----------------------------- # Penalised logistic regression