turnbull: Rename other references from s to p post change to EM-ICM

This commit is contained in:
RunasSudo 2023-10-29 12:45:59 +11:00
parent f4d436e608
commit c87f42a042
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -191,9 +191,9 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
(left_index, right_index)
}).collect();
// Initialise s
// Initialise p
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
let s = vec![1.0 / intervals.len() as f64; intervals.len()];
let p = vec![1.0 / intervals.len() as f64; intervals.len()];
let mut data = TurnbullData {
data_time_interval_indexes: data_time_interval_indexes,
@ -208,20 +208,20 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
progress_bar.reset();
progress_bar.println("Running EM-ICM algorithm to fit Turnbull estimator");
let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s);
let (p, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, p);
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
let mut acc = 1.0;
for j in 0..(data.num_intervals() - 1) {
acc -= s[j];
acc -= p[j];
survival_prob.push(acc);
}
// --------------------------------------------------
// Compute standard errors for survival probabilities
let hessian = compute_hessian(&data, &s);
let hessian = compute_hessian(&data, &p);
let mut survival_prob_se: DVector<f64>;
@ -233,7 +233,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
}
SEMethod::OIMDropZeros => {
// Drop rows/columns of Hessian corresponding to intervals with zero failure probability
let nonzero_intervals: Vec<usize> = (0..(data.num_intervals() - 1)).filter(|i| s[*i] > zero_tolerance).collect();
let nonzero_intervals: Vec<usize> = (0..(data.num_intervals() - 1)).filter(|i| p[*i] > zero_tolerance).collect();
let mut hessian_nonzero: DMatrix<f64> = DMatrix::zeros(nonzero_intervals.len(), nonzero_intervals.len());
for (nonzero_index1, orig_index1) in nonzero_intervals.iter().enumerate() {
@ -262,7 +262,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
return TurnbullResult {
failure_intervals: data.intervals,
failure_prob: s,
failure_prob: p,
survival_prob: survival_prob,
survival_prob_se: survival_prob_se.data.as_vec().clone(),
ll_model: ll,