Implement PenalisedLogit

This commit is contained in:
RunasSudo 2022-10-13 17:23:29 +11:00
parent 51008296c2
commit 461e00df78
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
3 changed files with 82 additions and 2 deletions

View File

@ -121,3 +121,28 @@ def test_regress_logit_ol12_23():
assert expbeta_gam.point == approx(1.169, abs=0.001)
assert expbeta_gam.ci_lower == approx(0.924, abs=0.001)
assert expbeta_gam.ci_upper == approx(1.477, abs=0.001)
def test_regress_penalisedlogit_kleinman():
    """Compare yli.regress with yli.PenalisedLogit for http://sas-and-r.blogspot.com/2010/11/example-815-firth-logistic-regression.html"""

    # 240 observations; every observation with Pred == 1 also has Outcome == 1
    predictor = [1] * 20 + [0] * 220
    outcome = [1] * 40 + [0] * 200
    df = pd.DataFrame({'Pred': predictor, 'Outcome': outcome})

    result = yli.regress(yli.PenalisedLogit, df, 'Outcome', 'Pred', exp=False)

    assert result.dof_model == 1

    # Intercept: point estimate, confidence limits and p value
    intercept = result.beta['(Intercept)']
    assert intercept.point == approx(-2.280389)
    assert intercept.ci_lower == approx(-2.765427)
    assert intercept.ci_upper == approx(-1.851695)
    assert result.pvalues['(Intercept)'] < 0.0001

    # Predictor: point estimate, confidence limits and p value
    pred = result.beta['Pred']
    assert pred.point == approx(5.993961)
    assert pred.ci_lower == approx(3.947048)
    assert pred.ci_upper == approx(10.852893)
    assert result.pvalues['Pred'] < 0.0001

    # Likelihood ratio test against the null model
    lrtest_result = result.lrtest_null()
    assert lrtest_result.statistic == approx(78.95473)
    assert lrtest_result.dof == 1
    assert lrtest_result.pvalue < 0.0001

View File

@ -16,7 +16,7 @@
from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist
from .fs import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
from .regress import regress, vif
from .regress import PenalisedLogit, regress, vif
from .sig_tests import chi2, mannwhitney, ttest_ind
def reload_me():

View File

@ -18,6 +18,7 @@ import numpy as np
import pandas as pd
import patsy
from scipy import stats
import statsmodels
import statsmodels.api as sm
from statsmodels.iolib.table import SimpleTable
from statsmodels.stats.outliers_influence import variance_inflation_factor
@ -286,7 +287,7 @@ def regress(
# Autodetect whether to exponentiate
if exp is None:
if model_class is sm.Logit:
if model_class is sm.Logit or model_class is PenalisedLogit:
exp = True
else:
exp = False
@ -308,6 +309,11 @@ def regress(
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df)
result = model.fit()
if isinstance(result, RegressionResult):
# Already processed!
result.exp = exp
return result
confint = result.conf_int()
beta = {t: Estimate(b, confint[0][t], confint[1][t]) for t, b in result.params.items()}
@ -331,3 +337,52 @@ def regress(
getattr(result, 'df_resid', None), getattr(result, 'rsquared', None), getattr(result, 'fvalue', None),
exp
)
# -----------------------------
# Penalised logistic regression
class PenalisedLogit(statsmodels.discrete.discrete_model.BinaryModel):
    """
    statsmodels-compatible model for computing Firth penalised logistic regression

    The fitting itself is delegated to the R "logistf" library via rpy2

    NB: This class expects to be used in the context of yli.regress()
    """

    def fit(self):
        """
        Fit the model using R logistf and return the outcome as a RegressionResult

        The final (exp) argument of the RegressionResult is passed as None; yli.regress() sets it afterwards
        """

        # Imported lazily so rpy2 is only needed when this model is actually fitted
        import rpy2.robjects as ro
        import rpy2.robjects.packages
        import rpy2.robjects.pandas2ri

        # Assume data is already cleaned from regress()
        # Convert bool to int otherwise rpy2 chokes
        r_input = self.data.frame.copy().replace({False: 0, True: 1})

        # Import logistf
        ro.packages.importr('logistf')

        pandas_aware = ro.default_converter + ro.pandas2ri.converter
        with ro.conversion.localconverter(pandas_aware), ro.local_context() as r_env:
            # Transfer the DataFrame and the formula to R under the names used in the R call below
            r_env['df'] = r_input
            r_env['formula_'] = self.formula

            # Fit the model
            model = ro.r('logistf(formula_, data=df)')

            # Collect the per-term estimates with their confidence limits
            term_names = model['terms']
            beta = {}
            for name, coef, lower, upper in zip(term_names, model['coefficients'], model['ci.lower'], model['ci.upper']):
                beta[name] = Estimate(coef, lower, upper)

            # ... and the per-term p values
            pvalues = dict(zip(term_names, model['prob']))

            nobs = model['n'][0]
            dof_model = model['df'][0]
            loglik = model['loglik']

            return RegressionResult(
                model,
                'Penalised Logistic Regression', 'Logit', 'Penalised ML',
                self.endog_names, nobs, dof_model, datetime.now(),
                beta, pvalues,
                loglik[0], loglik[1],
                None, None, None,
                None  # Set exp in regress()
            )