diff --git a/src/turnbull.rs b/src/turnbull.rs index 33f631b..ac60a88 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -172,18 +172,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i // Prepare for regression // Get Turnbull intervals - let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left) - all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open - all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true))); - all_time_points.dedup(); - all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap()); - - let mut intervals: Vec<(f64, f64)> = Vec::new(); - for i in 1..all_time_points.len() { - if all_time_points[i - 1].1 == true && all_time_points[i].1 == false { - intervals.push((all_time_points[i - 1].0, all_time_points[i].0)); - } - } + let intervals = get_turnbull_intervals(&data_times); // Recode times as indexes let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| { @@ -202,7 +191,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i // Initialise s let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64); - let data = TurnbullData { + let mut data = TurnbullData { data_time_interval_indexes: data_time_interval_indexes, intervals: intervals, }; @@ -215,48 +204,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Running iterative algorithm to fit Turnbull estimator"); - let mut iteration = 1; - loop { - // Get total failure probability for each observation (denominator of μ_ij) - let sum_fail_prob = DVector::from_iterator( - data.num_obs(), - data.data_time_interval_indexes - .iter() - .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) - ); - - // Compute π_j - let mut pi: DVector = DVector::zeros(data.num_intervals()); - for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() { - for j in *idx_left..(*idx_right + 1) { - pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64; - } - } - - let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); - - let converged = largest_delta_s <= fail_prob_tolerance; - - s = pi; - - // Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations - let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64; - let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64; - - // Update progress bar - progress_bar.set_position(progress_bar.position().max(progress3.max(progress2))); - progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s)); - - if converged { - progress_bar.println(format!("Converged in {} iterations", iteration)); - break; - } - - iteration += 1; - if iteration > max_iterations { - panic!("Exceeded --max-iterations"); - } - } + s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s); // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0) let mut survival_prob: Vec = Vec::with_capacity(data.num_intervals() - 1); @@ -274,32 +222,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Computing standard errors for survival probabilities"); - let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); - - for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) { - let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); - hessian_denominator = hessian_denominator.powi(2); - - let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case - let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian - - for h in idx_start..idx_end { - let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 }; - let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 }; - - hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator; - - for k in idx_start..h { - let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 }; - let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 }; - - let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator; - hessian[(h, k)] -= value; - hessian[(k, h)] -= value; - } - } - } - progress_bar.finish(); + let hessian = compute_hessian(&data, progress_bar.clone(), &s); let mut survival_prob_se: DVector; @@ -346,6 +269,101 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i }; } +fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { + let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left) + all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open + all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true))); + all_time_points.dedup(); + all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap()); + + let mut intervals: Vec<(f64, f64)> = Vec::new(); + for i in 1..all_time_points.len() { + if all_time_points[i - 1].1 == true && all_time_points[i].1 == false { + intervals.push((all_time_points[i - 1].0, all_time_points[i].0)); + } + } + + return intervals; +} + +fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector) -> DVector { + let mut iteration = 1; + loop { + // Get total failure probability for each observation (denominator of μ_ij) + let sum_fail_prob = DVector::from_iterator( + data.num_obs(), + data.data_time_interval_indexes + .iter() + .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) + ); + + // Compute π_j + let mut pi: DVector = DVector::zeros(data.num_intervals()); + for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() { + for j in *idx_left..(*idx_right + 1) { + pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64; + } + } + + let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); + + let converged = largest_delta_s <= fail_prob_tolerance; + + s = pi; + + // Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations + let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64; + let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64; + + // Update progress bar + progress_bar.set_position(progress_bar.position().max(progress3.max(progress2))); + progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s)); + + if converged { + progress_bar.println(format!("Converged in {} iterations", iteration)); + break; + } + + iteration += 1; + if iteration > max_iterations { + panic!("Exceeded --max-iterations"); + } + } + + return s; +} + +fn compute_hessian(data: &TurnbullData, progress_bar: ProgressBar, s: &DVector) -> DMatrix { + let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); + + for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) { + let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); + hessian_denominator = hessian_denominator.powi(2); + + let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case + let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian + + for h in idx_start..idx_end { + let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 }; + let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 }; + + hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator; + + for k in idx_start..h { + let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 }; + let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 }; + + let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator; + hessian[(h, k)] -= value; + hessian[(k, h)] -= value; + } + } + } + progress_bar.finish(); + + return hessian; +} + #[derive(Serialize, Deserialize)] pub struct TurnbullResult { pub failure_intervals: Vec<(f64, f64)>,