Implement PenalisedLogit
This commit is contained in:
parent
51008296c2
commit
461e00df78
@ -121,3 +121,28 @@ def test_regress_logit_ol12_23():
|
||||
assert expbeta_gam.point == approx(1.169, abs=0.001)
|
||||
assert expbeta_gam.ci_lower == approx(0.924, abs=0.001)
|
||||
assert expbeta_gam.ci_upper == approx(1.477, abs=0.001)
|
||||
|
||||
def test_regress_penalisedlogit_kleinman():
    """
    Check yli.regress with yli.PenalisedLogit against the worked example at
    http://sas-and-r.blogspot.com/2010/11/example-815-firth-logistic-regression.html

    The data set has 240 rows: Pred is 1 for the first 20 rows and Outcome is 1
    for the first 40, so every row with Pred == 1 also has Outcome == 1.
    """

    n_total, n_pred, n_outcome = 240, 20, 40
    data = pd.DataFrame({
        'Pred': [1] * n_pred + [0] * (n_total - n_pred),
        'Outcome': [1] * n_outcome + [0] * (n_total - n_outcome)
    })

    fit = yli.regress(yli.PenalisedLogit, data, 'Outcome', 'Pred', exp=False)

    assert fit.dof_model == 1

    # Reference values per term: (point estimate, CI lower bound, CI upper bound)
    expected = {
        '(Intercept)': (-2.280389, -2.765427, -1.851695),
        'Pred': (5.993961, 3.947048, 10.852893),
    }
    for term, (point, lower, upper) in expected.items():
        estimate = fit.beta[term]
        assert estimate.point == approx(point)
        assert estimate.ci_lower == approx(lower)
        assert estimate.ci_upper == approx(upper)
        assert fit.pvalues[term] < 0.0001

    # Likelihood ratio test against the null model
    lr = fit.lrtest_null()
    assert lr.statistic == approx(78.95473)
    assert lr.dof == 1
    assert lr.pvalue < 0.0001
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist
|
||||
from .fs import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
|
||||
from .regress import regress, vif
|
||||
from .regress import PenalisedLogit, regress, vif
|
||||
from .sig_tests import chi2, mannwhitney, ttest_ind
|
||||
|
||||
def reload_me():
|
||||
|
@ -18,6 +18,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import patsy
|
||||
from scipy import stats
|
||||
import statsmodels
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.iolib.table import SimpleTable
|
||||
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
||||
@ -286,7 +287,7 @@ def regress(
|
||||
|
||||
# Autodetect whether to exponentiate
|
||||
if exp is None:
|
||||
if model_class is sm.Logit:
|
||||
if model_class is sm.Logit or model_class is PenalisedLogit:
|
||||
exp = True
|
||||
else:
|
||||
exp = False
|
||||
@ -308,6 +309,11 @@ def regress(
|
||||
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df)
|
||||
result = model.fit()
|
||||
|
||||
if isinstance(result, RegressionResult):
|
||||
# Already processed!
|
||||
result.exp = exp
|
||||
return result
|
||||
|
||||
confint = result.conf_int()
|
||||
beta = {t: Estimate(b, confint[0][t], confint[1][t]) for t, b in result.params.items()}
|
||||
|
||||
@ -331,3 +337,52 @@ def regress(
|
||||
getattr(result, 'df_resid', None), getattr(result, 'rsquared', None), getattr(result, 'fvalue', None),
|
||||
exp
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Penalised logistic regression
|
||||
|
||||
class PenalisedLogit(statsmodels.discrete.discrete_model.BinaryModel):
    """
    statsmodels-style binary model for Firth penalised logistic regression

    The fit is delegated to the R "logistf" library through rpy2.

    NB: This class expects to be used in the context of yli.regress().
    fit() returns a RegressionResult directly, with exp left as None so that
    regress() can set it afterwards.
    """

    def fit(self):
        """
        Fit the model via R's logistf and return the outcome as a RegressionResult

        Reads self.data.frame, self.formula and self.endog_names.
        rpy2 is imported lazily here, so it is only needed when this model is fitted.
        """

        import rpy2.robjects as ro
        import rpy2.robjects.packages
        import rpy2.robjects.pandas2ri

        # The data is assumed to have been cleaned by regress() already; work on a copy
        frame = self.data.frame.copy()

        # rpy2 chokes on bool values, so recode them as 0/1
        frame = frame.replace({False: 0, True: 1})

        # Make logistf available in the R session
        ro.packages.importr('logistf')

        converter = ro.default_converter + ro.pandas2ri.converter
        with ro.conversion.localconverter(converter), ro.local_context() as r_env:
            # Hand the DataFrame and the formula over to R
            r_env['df'] = frame
            r_env['formula_'] = self.formula

            # Run the penalised fit
            fitted = ro.r('logistf(formula_, data=df)')

            # Per-term point estimates with confidence limits, keyed by term name
            beta = {
                term: Estimate(point, lower, upper)
                for term, point, lower, upper in zip(
                    fitted['terms'], fitted['coefficients'], fitted['ci.lower'], fitted['ci.upper']
                )
            }
            pvalues = dict(zip(fitted['terms'], fitted['prob']))

            loglik = fitted['loglik']

            return RegressionResult(
                fitted,
                'Penalised Logistic Regression', 'Logit', 'Penalised ML',
                self.endog_names, fitted['n'][0], fitted['df'][0], datetime.now(),
                beta, pvalues,
                loglik[0], loglik[1],
                None, None, None,
                None  # exp is set by regress()
            )
|
||||
|
Loading…
Reference in New Issue
Block a user