turnbull: Disregard ICM step when computing likelihood-ratio confidence intervals

ICM step performance is heavily degraded when constraints are required.
It is much faster to rely on the EM step alone.
1275% speedup!
This commit is contained in:
RunasSudo 2023-12-25 22:21:50 +11:00
parent c9a8b5b8a5
commit 307aff6f14
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed file with 21 additions and 13 deletions

View File

@ -312,6 +312,8 @@ fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_it
// ------- // -------
// EM step // EM step
// TODO: Do EM step multiple times per ICM step?
let p_after_em = do_em_step(data, &p, &s, &constraint); let p_after_em = do_em_step(data, &p, &s, &constraint);
let s_after_em = p_to_s(&p_after_em); let s_after_em = p_to_s(&p_after_em);
@ -324,13 +326,18 @@ fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_it
// -------- // --------
// ICM step // ICM step
let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_tolerance, &constraint, ll_model_after_em); let ll_model_new;
if constraint.is_none() {
(p, s, ll_model_new) = do_icm_step(data, &p, &s, ll_tolerance, ll_model_after_em);
} else {
// ICM step is very slow with constraints, so skip it and just do EM
ll_model_new = ll_model_after_em;
}
let ll_change = ll_model_new - ll_model; let ll_change = ll_model_new - ll_model;
let converged = ll_change <= ll_tolerance; let converged = ll_change <= ll_tolerance;
p = p_new;
s = s_new;
ll_model = ll_model_new; ll_model = ll_model_new;
// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
@ -421,7 +428,7 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, constraint: &Opti
return p_new; return p_new;
} }
fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f64, constraint: &Option<Constraint>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) { fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f64, /* constraint: &Option<Constraint>, */ ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
// Compute Λ, the cumulative hazard // Compute Λ, the cumulative hazard
// Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted // Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
// The entry at lambda[j] corresponds to the survival immediately before time point j + 1 // The entry at lambda[j] corresponds to the survival immediately before time point j + 1
@ -495,15 +502,16 @@ fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f6
ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
// Constrain if required // Constrain if required
if let Some(c) = constraint { // This is very slow, so support constraints only in the EM step
let cur_survival_prob = s_new[c.time_index]; //if let Some(c) = constraint {
let _ = &mut p_new[0..c.time_index].iter_mut().for_each(|x| *x *= (1.0 - c.survival_prob) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability // let cur_survival_prob = s_new[c.time_index];
let _ = &mut p_new[c.time_index..].iter_mut().for_each(|x| *x *= c.survival_prob / cur_survival_prob); // let _ = &mut p_new[0..c.time_index].iter_mut().for_each(|x| *x *= (1.0 - c.survival_prob) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability
// let _ = &mut p_new[c.time_index..].iter_mut().for_each(|x| *x *= c.survival_prob / cur_survival_prob);
s_new = p_to_s(&p_new); //
let likelihood_obs_new = get_likelihood_obs(data, &s_new); // s_new = p_to_s(&p_new);
ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); // let likelihood_obs_new = get_likelihood_obs(data, &s_new);
} // ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
//}
if ll_model_new > ll_model { if ll_model_new > ll_model {
return (p_new, s_new, ll_model_new); return (p_new, s_new, ll_model_new);