Plotting functions for interval-censored Cox model
This commit is contained in:
parent
a0a9900dfa
commit
7746ca275e
@ -682,6 +682,12 @@ class IntervalCensoredCox(RegressionModel):
|
||||
def model_short_name(self):
|
||||
return 'Interval-Censored Cox'
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# TODO: Documentation
|
||||
self.lambda_ = None
|
||||
|
||||
@classmethod
|
||||
def fit(cls, data_dep, data_ind):
|
||||
if len(data_dep.columns) != 2:
|
||||
@ -716,10 +722,58 @@ class IntervalCensoredCox(RegressionModel):
|
||||
pvalue=stats.norm.cdf(-np.abs(raw_param) / raw_se) * 2
|
||||
) for raw_name, raw_param, raw_se in zip(data_ind.columns, raw_result['params'], raw_result['params_se'])}
|
||||
|
||||
result.cumulative_hazard = pd.Series(data=raw_result['cumulative_hazard'], index=raw_result['cumulative_hazard_times'])
|
||||
|
||||
result.ll_model = raw_result['ll_model']
|
||||
result.ll_null = raw_result['ll_null']
|
||||
|
||||
return result
|
||||
|
||||
def survival_function(self):
|
||||
# TODO: Documentation
|
||||
|
||||
return np.exp(-self.cumulative_hazard)
|
||||
|
||||
def plot_survival_function(self, ax=None):
|
||||
# TODO: Documentation
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
sf = self.survival_function()
|
||||
|
||||
# Draw straight lines
|
||||
# np.concatenate(...) to force starting drawing from time 0, survival 100%
|
||||
xpoints = np.concatenate([[0], sf.index]).repeat(2)[1:]
|
||||
ypoints = np.concatenate([[1], sf]).repeat(2)[:-1]
|
||||
|
||||
ax.plot(xpoints, ypoints)
|
||||
|
||||
ax.set_xlabel('Analysis time')
|
||||
ax.set_ylabel('Survival probability')
|
||||
ax.set_xlim(left=0)
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
def plot_loglog_survival(self, ax=None):
|
||||
# TODO: Documentation
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
sf = self.survival_function()
|
||||
|
||||
# Draw straight lines
|
||||
xpoints = np.log(sf.index).to_numpy().repeat(2)[1:]
|
||||
ypoints = (-np.log(-np.log(sf))).to_numpy().repeat(2)[:-1]
|
||||
|
||||
ax.plot(xpoints, ypoints)
|
||||
|
||||
ax.set_xlabel('ln(Analysis time)')
|
||||
ax.set_ylabel('−ln(−ln(Survival probability))')
|
||||
|
||||
class Logit(RegressionModel):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user