Plotting functions for interval-censored Cox model

This commit is contained in:
RunasSudo 2023-04-22 01:49:45 +10:00
parent a0a9900dfa
commit 7746ca275e
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -682,6 +682,12 @@ class IntervalCensoredCox(RegressionModel):
def model_short_name(self):
return 'Interval-Censored Cox'
def __init__(self):
super().__init__()
# TODO: Documentation
self.lambda_ = None
@classmethod
def fit(cls, data_dep, data_ind):
if len(data_dep.columns) != 2:
@ -716,11 +722,59 @@ class IntervalCensoredCox(RegressionModel):
pvalue=stats.norm.cdf(-np.abs(raw_param) / raw_se) * 2
) for raw_name, raw_param, raw_se in zip(data_ind.columns, raw_result['params'], raw_result['params_se'])}
result.cumulative_hazard = pd.Series(data=raw_result['cumulative_hazard'], index=raw_result['cumulative_hazard_times'])
result.ll_model = raw_result['ll_model']
result.ll_null = raw_result['ll_null']
return result
def survival_function(self):
# TODO: Documentation
return np.exp(-self.cumulative_hazard)
def plot_survival_function(self, ax=None):
# TODO: Documentation
import matplotlib.pyplot as plt
if ax is None:
fig, ax = plt.subplots()
sf = self.survival_function()
# Draw straight lines
# np.concatenate(...) to force starting drawing from time 0, survival 100%
xpoints = np.concatenate([[0], sf.index]).repeat(2)[1:]
ypoints = np.concatenate([[1], sf]).repeat(2)[:-1]
ax.plot(xpoints, ypoints)
ax.set_xlabel('Analysis time')
ax.set_ylabel('Survival probability')
ax.set_xlim(left=0)
ax.set_ylim(0, 1)
def plot_loglog_survival(self, ax=None):
# TODO: Documentation
import matplotlib.pyplot as plt
if ax is None:
fig, ax = plt.subplots()
sf = self.survival_function()
# Draw straight lines
xpoints = np.log(sf.index).to_numpy().repeat(2)[1:]
ypoints = (-np.log(-np.log(sf))).to_numpy().repeat(2)[:-1]
ax.plot(xpoints, ypoints)
ax.set_xlabel('ln(Analysis time)')
ax.set_ylabel('−ln(−ln(Survival probability))')
class Logit(RegressionModel):
"""
Logistic regression