diff --git a/src/intcox.rs b/src/intcox.rs index 4d09688..87ad8e0 100644 --- a/src/intcox.rs +++ b/src/intcox.rs @@ -244,11 +244,8 @@ pub fn fit_interval_censored_cox(data_times: MatrixXx2, mut data_indep: DMa s = compute_s(&data, &lambda_new, &exp_z_beta); let ll_model_new = log_likelihood_obs(&s).sum(); - let mut converged = true; let ll_change = ll_model_new - ll_model; - if ll_change > ll_tolerance { - converged = false; - } + let converged = ll_change <= ll_tolerance; lambda = lambda_new; beta = beta_new; diff --git a/src/turnbull.rs b/src/turnbull.rs index d08e827..9d0b49d 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -43,9 +43,9 @@ pub struct TurnbullArgs { #[arg(long, default_value="1000")] max_iterations: u32, - /// Terminate algorithm when the absolute change in failure probability in each interval is less than this tolerance - #[arg(long, default_value="0.0001")] - fail_prob_tolerance: f64, + /// Terminate algorithm when the absolute change in log-likelihood is less than this tolerance + #[arg(long, default_value="0.01")] + ll_tolerance: f64, /// Method for computing standard error or survival probabilities #[arg(long, value_enum, default_value="oim")] @@ -74,7 +74,7 @@ pub fn main(args: TurnbullArgs) { // Fit regression let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); - let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.fail_prob_tolerance, args.se_method, args.zero_tolerance); + let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance); // Display output match args.output { @@ -169,7 +169,7 @@ impl TurnbullData { } } -pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { +pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { // ---------------------- // Prepare for regression @@ -207,7 +207,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Running iterative algorithm to fit Turnbull estimator"); - let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s); + let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s); // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0) let mut survival_prob: Vec = Vec::with_capacity(data.num_intervals() - 1); @@ -217,9 +217,6 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i survival_prob.push(acc); } - // Compute log-likelihood - let ll = compute_log_likelihood(&data, &s); - // -------------------------------------------------- // Compute standard errors for survival probabilities @@ -288,28 +285,33 @@ fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { return intervals; } -fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec) -> Vec { +fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut s: Vec) -> (Vec, f64) { + // Get likelihood for each observation (denominator of μ_ij) + let mut likelihood_obs = get_likelihood_obs(data, &s); + let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum(); + let mut iteration = 1; loop { - // Get total failure probability for each observation (denominator of μ_ij) - let sum_fail_prob = get_sum_fail_prob(data, &s); + // Compute π_j to update s + let pi = compute_pi(data, &s, likelihood_obs); - // Compute π_j - let pi = compute_pi(data, &s, sum_fail_prob); + let likelihood_obs_new = get_likelihood_obs(data, &pi); + let ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); - let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); - - let converged = largest_delta_s <= fail_prob_tolerance; + let ll_change = ll_model_new - ll_model; + let converged = ll_change <= ll_tolerance; s = pi; + likelihood_obs = likelihood_obs_new; + ll_model = ll_model_new; - // Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations + // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64; - let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64; + let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance.log10() * u64::MAX as f64) as u64; // Update progress bar progress_bar.set_position(progress_bar.position().max(progress3.max(progress2))); - progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s)); + progress_bar.set_message(format!("Iteration {} (LL = {:.4}, ΔLL = {:.4})", iteration + 1, ll_model, ll_change)); if converged { progress_bar.println(format!("Converged in {} iterations", iteration)); @@ -322,34 +324,34 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma } } - return s; + return (s, ll_model); } -fn get_sum_fail_prob(data: &TurnbullData, s: &Vec) -> Vec { +fn get_likelihood_obs(data: &TurnbullData, s: &Vec) -> Vec { return data.data_time_interval_indexes .par_iter() .map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum()) .collect(); } -fn compute_pi(data: &TurnbullData, s: &Vec, sum_fail_prob: Vec) -> Vec { +fn compute_pi(data: &TurnbullData, s: &Vec, likelihood_obs: Vec) -> Vec { /* let mut pi: Vec = vec![0.0; data.num_intervals()]; - for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) { + for ((idx_left, idx_right), likelihood_obs_i) in data.data_time_interval_indexes.iter().zip(likelihood_obs.iter()) { for j in *idx_left..(*idx_right + 1) { - pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64; + pi[j] += s[j] / likelihood_obs_i / data.num_obs() as f64; } } */ - let pi = data.data_time_interval_indexes.par_iter().zip(sum_fail_prob.par_iter()) + let pi = data.data_time_interval_indexes.par_iter().zip(likelihood_obs.par_iter()) .fold_with( // Compute the contributions to pi[j] for each observation and sum them in parallel using fold_with vec![0.0; data.num_intervals()], - |mut acc, ((idx_left, idx_right), sum_fail_prob_i)| { + |mut acc, ((idx_left, idx_right), likelihood_obs_i)| { // Contributions to pi[j] for the i-th observation for j in *idx_left..(*idx_right + 1) { - acc[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64; + acc[j] += s[j] / likelihood_obs_i / data.num_obs() as f64; } acc } @@ -366,17 +368,6 @@ fn compute_pi(data: &TurnbullData, s: &Vec, sum_fail_prob: Vec) -> Vec return pi; } -fn compute_log_likelihood(data: &TurnbullData, s: &Vec) -> f64 { - let mut ll = 0.0; - - for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { - let likelihood_ob: f64 = s[*idx_left..(*idx_right + 1)].iter().sum(); - ll += likelihood_ob.ln(); - } - - return ll; -} - fn compute_hessian(data: &TurnbullData, s: &Vec) -> DMatrix { let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); diff --git a/tests/turnbull.rs b/tests/turnbull.rs index 04f44d9..41499a3 100644 --- a/tests/turnbull.rs +++ b/tests/turnbull.rs @@ -26,7 +26,7 @@ fn test_turnbull_minitab() { // Fit regression let progress_bar = ProgressBar::hidden(); - let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.0001, turnbull::SEMethod::OIM, 0.0001); + let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.01, turnbull::SEMethod::OIM, 0.0001); assert_eq!(result.failure_intervals[0], (20000.0, 30000.0)); assert_eq!(result.failure_intervals[1], (30000.0, 40000.0));