Plotting functions for interval-censored Cox model
This commit is contained in:
parent
a0a9900dfa
commit
7746ca275e
@ -682,6 +682,12 @@ class IntervalCensoredCox(RegressionModel):
|
|||||||
def model_short_name(self):
|
def model_short_name(self):
|
||||||
return 'Interval-Censored Cox'
|
return 'Interval-Censored Cox'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# TODO: Documentation
|
||||||
|
self.lambda_ = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fit(cls, data_dep, data_ind):
|
def fit(cls, data_dep, data_ind):
|
||||||
if len(data_dep.columns) != 2:
|
if len(data_dep.columns) != 2:
|
||||||
@ -716,10 +722,58 @@ class IntervalCensoredCox(RegressionModel):
|
|||||||
pvalue=stats.norm.cdf(-np.abs(raw_param) / raw_se) * 2
|
pvalue=stats.norm.cdf(-np.abs(raw_param) / raw_se) * 2
|
||||||
) for raw_name, raw_param, raw_se in zip(data_ind.columns, raw_result['params'], raw_result['params_se'])}
|
) for raw_name, raw_param, raw_se in zip(data_ind.columns, raw_result['params'], raw_result['params_se'])}
|
||||||
|
|
||||||
|
result.cumulative_hazard = pd.Series(data=raw_result['cumulative_hazard'], index=raw_result['cumulative_hazard_times'])
|
||||||
|
|
||||||
result.ll_model = raw_result['ll_model']
|
result.ll_model = raw_result['ll_model']
|
||||||
result.ll_null = raw_result['ll_null']
|
result.ll_null = raw_result['ll_null']
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def survival_function(self):
|
||||||
|
# TODO: Documentation
|
||||||
|
|
||||||
|
return np.exp(-self.cumulative_hazard)
|
||||||
|
|
||||||
|
def plot_survival_function(self, ax=None):
|
||||||
|
# TODO: Documentation
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
sf = self.survival_function()
|
||||||
|
|
||||||
|
# Draw straight lines
|
||||||
|
# np.concatenate(...) to force starting drawing from time 0, survival 100%
|
||||||
|
xpoints = np.concatenate([[0], sf.index]).repeat(2)[1:]
|
||||||
|
ypoints = np.concatenate([[1], sf]).repeat(2)[:-1]
|
||||||
|
|
||||||
|
ax.plot(xpoints, ypoints)
|
||||||
|
|
||||||
|
ax.set_xlabel('Analysis time')
|
||||||
|
ax.set_ylabel('Survival probability')
|
||||||
|
ax.set_xlim(left=0)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
|
||||||
|
def plot_loglog_survival(self, ax=None):
|
||||||
|
# TODO: Documentation
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
sf = self.survival_function()
|
||||||
|
|
||||||
|
# Draw straight lines
|
||||||
|
xpoints = np.log(sf.index).to_numpy().repeat(2)[1:]
|
||||||
|
ypoints = (-np.log(-np.log(sf))).to_numpy().repeat(2)[:-1]
|
||||||
|
|
||||||
|
ax.plot(xpoints, ypoints)
|
||||||
|
|
||||||
|
ax.set_xlabel('ln(Analysis time)')
|
||||||
|
ax.set_ylabel('−ln(−ln(Survival probability))')
|
||||||
|
|
||||||
class Logit(RegressionModel):
|
class Logit(RegressionModel):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user