diff --git a/src/turnbull.rs b/src/turnbull.rs index 2b64ea5..f536f3b 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -311,7 +311,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma // -------- // ICM step - let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_model_after_em); + let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_tolerance, ll_model_after_em); let ll_change = ll_model_new - ll_model; let converged = ll_change <= ll_tolerance; @@ -400,7 +400,7 @@ fn do_em_step(data: &TurnbullData, p: &Vec, s: &Vec) -> Vec { return p_new; } -fn do_icm_step(data: &TurnbullData, _p: &Vec, s: &Vec, ll_model: f64) -> (Vec, Vec, f64) { +fn do_icm_step(data: &TurnbullData, p: &Vec, s: &Vec, ll_tolerance: f64, ll_model: f64) -> (Vec, Vec, f64) { // Compute Λ, the cumulative hazard // Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted // The entry at lambda[j] corresponds to the survival immediately before time point j + 1 @@ -477,6 +477,12 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec, s: &Vec, ll_model: f64) return (p_new, s_new, ll_model_new); } + if ll_model - ll_model_new < ll_tolerance { + // LL decreased but by less than ll_tolerance + // This might happen because the EM algorithm already obtained the exact solution + return (p.clone(), s.clone(), ll_model); + } + step_size_exponent += 1; if step_size_exponent > 10 {