Implement TimeVaryingCox

This commit is contained in:
RunasSudo 2023-07-17 01:34:34 +10:00
parent fd7384f810
commit 26a6766f0e
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
2 changed files with 56 additions and 1 deletions

View File

@ -20,7 +20,7 @@ from .descriptives import auto_correlations, auto_descriptives
from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist
from .graphs import init_fonts, HorizontalEffectPlot from .graphs import init_fonts, HorizontalEffectPlot
from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted
from .regress import Cox, GLM, IntervalCensoredCox, Logit, OLS, OrdinalLogit, PenalisedLogit, Poisson, regress, vif from .regress import Cox, GLM, IntervalCensoredCox, Logit, OLS, OrdinalLogit, PenalisedLogit, Poisson, TimeVaryingCox, regress, vif
from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind, ttest_ind_multiple from .sig_tests import anova_oneway, auto_univariable, chi2, mannwhitney, pearsonr, spearman, ttest_ind, ttest_ind_multiple
from .survival import kaplanmeier, logrank, turnbull from .survival import kaplanmeier, logrank, turnbull
from .utils import as_ordinal from .utils import as_ordinal

View File

@ -165,6 +165,12 @@ def regress(
if reduced is not None: if reduced is not None:
fit_kwargs['reduced'] = reduced fit_kwargs['reduced'] = reduced
# Bodge for TimeVaryingCox
if model_class.__name__ == 'TimeVaryingCox':
additional_columns.append('index')
additional_columns.append('start')
additional_columns.append('stop')
# Preprocess data, check for NaN and get design matrices # Preprocess data, check for NaN and get design matrices
df_ref = weakref.ref(df) df_ref = weakref.ref(df)
df_clean, dmatrices, dep_categories = df_to_dmatrices(df, dep, formula, nan_policy, additional_columns) df_clean, dmatrices, dep_categories = df_to_dmatrices(df, dep, formula, nan_policy, additional_columns)
@ -175,6 +181,10 @@ def regress(
if status is not None: if status is not None:
fit_kwargs['status'] = df_clean[status] fit_kwargs['status'] = df_clean[status]
# Bodge for TimeVaryingCox
if model_class.__name__ == 'TimeVaryingCox':
dmatrices = (dmatrices[0], dmatrices[1].join(df_clean[['index', 'start', 'stop']]))
# Fit model # Fit model
result = model_class.fit(dmatrices[0], dmatrices[1], **fit_kwargs) result = model_class.fit(dmatrices[0], dmatrices[1], **fit_kwargs)
@ -685,6 +695,8 @@ class Cox(RegressionModel):
result = cls() result = cls()
result.exp = True result.exp = True
result.cov_type = 'nonrobust' result.cov_type = 'nonrobust'
result.nevents = status.sum()
result.dof_model = len(data_ind.columns)
# Perform regression # Perform regression
raw_result = sm.PHReg(endog=data_dep, exog=data_ind, status=status, missing='raise').fit(disp=False) raw_result = sm.PHReg(endog=data_dep, exog=data_ind, status=status, missing='raise').fit(disp=False)
@ -1387,6 +1399,49 @@ class Poisson(RegressionModel):
return result return result
class TimeVaryingCox(RegressionModel):
    """
    Cox regression with time-varying covariates, fitted using lifelines' CoxTimeVaryingFitter

    Requires the data to be in the lifelines long format. The fit reads the
    columns named ``index`` (subject identifier), ``start`` and ``stop``
    (interval bounds) from the independent-variable frame, and uses the first
    column of the dependent-variable frame as the event indicator.
    """

    @property
    def model_long_name(self):
        """Full display name of the model"""
        return 'Time-Varying Cox Regression'

    @property
    def model_short_name(self):
        """Abbreviated display name of the model"""
        return 'Time-Varying Cox'

    @classmethod
    def fit(cls, data_dep, data_ind, exposure=None, method='newton', maxiter=None, start_params=None):
        """
        Fit the model and return a populated result instance

        :param data_dep: DataFrame whose first column is passed to lifelines as the event column
        :param data_ind: DataFrame joined onto *data_dep*; must supply the ``index``, ``start`` and ``stop`` columns
        :param exposure: Accepted but not used by this implementation
        :param method: Accepted but not used by this implementation
        :param maxiter: Accepted but not used by this implementation
        :param start_params: Accepted but not used by this implementation
        :return: New instance of this class with fit statistics, terms and covariance matrix set
        """

        result = cls()
        result.exp = True  # NOTE(review): presumably flags that estimates are reported exponentiated - confirm against RegressionModel
        result.cov_type = 'nonrobust'

        # Imported at call time rather than at module level
        import lifelines

        # Perform regression (penalizer=0, i.e. no penalty term is requested)
        fitter = lifelines.CoxTimeVaryingFitter(penalizer=0)
        fitter.fit(
            data_dep.join(data_ind),
            id_col='index',
            event_col=data_dep.columns[0],
            start_col='start',
            stop_col='stop'
        )

        # Summary statistics; several are read from private lifelines attributes
        result.nobs = fitter._n_unique  # presumably the number of distinct subjects - verify against lifelines
        result.nevents = fitter.event_observed.sum()
        result.dof_model = len(fitter.params_)

        result.ll_model = fitter.log_likelihood_
        result.ll_null = fitter._log_likelihood_null

        # One SingleTerm per coefficient: point estimate, confidence bounds and p value
        coefficient_rows = zip(
            fitter.params_.index,
            fitter.params_,
            fitter.confidence_intervals_.itertuples(index=False),
            fitter._compute_p_values()
        )
        terms = {}
        for term_name, coefficient, bounds, p_value in coefficient_rows:
            terms[term_name] = SingleTerm(
                raw_name=term_name,
                beta=Estimate(coefficient, bounds[0], bounds[1]),
                pvalue=p_value
            )
        result.terms = terms

        result.vcov = fitter.variance_matrix_

        return result
# ------------------ # ------------------
# Brant test helpers # Brant test helpers