diff --git a/yli/regress.py b/yli/regress.py index 07d1db8..4c220c1 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -682,6 +682,12 @@ class IntervalCensoredCox(RegressionModel): def model_short_name(self): return 'Interval-Censored Cox' + def __init__(self): + super().__init__() + + # TODO: Documentation + self.lambda_ = None + @classmethod def fit(cls, data_dep, data_ind): if len(data_dep.columns) != 2: @@ -716,10 +722,58 @@ class IntervalCensoredCox(RegressionModel): pvalue=stats.norm.cdf(-np.abs(raw_param) / raw_se) * 2 ) for raw_name, raw_param, raw_se in zip(data_ind.columns, raw_result['params'], raw_result['params_se'])} + result.cumulative_hazard = pd.Series(data=raw_result['cumulative_hazard'], index=raw_result['cumulative_hazard_times']) + result.ll_model = raw_result['ll_model'] result.ll_null = raw_result['ll_null'] return result + + def survival_function(self): + # TODO: Documentation + + return np.exp(-self.cumulative_hazard) + + def plot_survival_function(self, ax=None): + # TODO: Documentation + + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots() + + sf = self.survival_function() + + # Draw straight lines + # np.concatenate(...) to force starting drawing from time 0, survival 100% + xpoints = np.concatenate([[0], sf.index]).repeat(2)[1:] + ypoints = np.concatenate([[1], sf]).repeat(2)[:-1] + + ax.plot(xpoints, ypoints) + + ax.set_xlabel('Analysis time') + ax.set_ylabel('Survival probability') + ax.set_xlim(left=0) + ax.set_ylim(0, 1) + + def plot_loglog_survival(self, ax=None): + # TODO: Documentation + + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots() + + sf = self.survival_function() + + # Draw straight lines + xpoints = np.log(sf.index).to_numpy().repeat(2)[1:] + ypoints = (-np.log(-np.log(sf))).to_numpy().repeat(2)[:-1] + + ax.plot(xpoints, ypoints) + + ax.set_xlabel('ln(Analysis time)') + ax.set_ylabel('−ln(−ln(Survival probability))') class Logit(RegressionModel): """