diff --git a/src/root_finding.rs b/src/root_finding.rs index 9062eaf..e091a30 100644 --- a/src/root_finding.rs +++ b/src/root_finding.rs @@ -102,6 +102,10 @@ impl AndersonBjorckRootFinder { } } + pub fn bounds(&self) -> (f64, f64) { + return (self.bound_lower, self.bound_upper); + } + pub fn precision(&self) -> f64 { return (self.bound_upper - self.bound_lower).abs(); } diff --git a/src/turnbull.rs b/src/turnbull.rs index ace0568..69017f9 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -19,9 +19,10 @@ const CHI2_1DF_95: f64 = 3.8414588; use std::fs::File; use std::io::{self, BufReader}; +use std::sync::{Arc, RwLock}; use clap::{Args, ValueEnum}; -use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use nalgebra::{DMatrix, DVector, Matrix2xX}; use prettytable::{Table, format, row}; use rayon::prelude::*; @@ -249,9 +250,58 @@ pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Computing confidence intervals by likelihood ratio test"); - let confidence_intervals = (1..data.num_intervals()).into_par_iter() - .map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j)) - .progress_with(progress_bar.clone()) + // (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper)) + // TODO: Refactor this (unsafe code?) - each thread reads/writes only one value so there is no need for locking + let ci_with_bounds = Arc::new( + Vec::from_iter((1..data.num_intervals()).map(|_| RwLock::new((f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN))))) + ); + + // First do intervals with nonzero failure probability + (1..data.num_intervals()).into_par_iter() + .for_each(|j| { + if p[j - 1] <= 0.0001 { // To see if the survival probability at the j-th time index is the same as (j-1)-th, check the (j-1)-th failure probability + return; + } + + let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, None); + let mut r = ci_with_bounds[j - 1].write().unwrap(); + *r = ci; + + progress_bar.inc(1); + }); + + // Fill initial guesses for intervals with zero failure probability + let mut initial_guesses = Vec::with_capacity(data.num_intervals() - 1); + for j in 1..data.num_intervals() { + if p[j - 1] > 0.0001 { + let r = ci_with_bounds[j - 1].read().unwrap(); + initial_guesses.push(Some((r.1, r.3))); + } else if j >= 2 { + initial_guesses.push(initial_guesses[j - 2]); // Carry forward final bounds from last time point + } else { + initial_guesses.push(None); + } + } + + // Now do intervals with zero failure probability + (1..data.num_intervals()).into_par_iter() + .for_each(|j| { + if p[j - 1] > 0.0001 { + return; + } + + let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, initial_guesses[j - 1]); + let mut r = ci_with_bounds[j - 1].write().unwrap(); + *r = ci; + + progress_bar.inc(1); + }); + + let confidence_intervals = ci_with_bounds.iter() + .map(|x| { + let r = x.read().unwrap(); + (r.0, r.2) + }) .collect(); survival_prob_ci = Some(confidence_intervals); @@ -620,8 +670,10 @@ fn compute_hessian(data: &TurnbullData, p: &Vec) -> DMatrix { return hessian; } -fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize) -> (f64, f64) { +fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize, initial_guess: Option<((f64, f64), (f64, f64))>) -> (f64, (f64, f64), f64, (f64, f64)) { + // ------------------------------ // Compute lower confidence limit + let mut root_finder = AndersonBjorckRootFinder::new( 0.0, s[time_index], f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of @@ -632,22 +684,37 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case } + // Use initial guess if available + if let Some(((initial_left, initial_right), _)) = initial_guess { + let value_left = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_left)) - CHI2_1DF_95; + let value_right = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_right)) - CHI2_1DF_95; + + if value_left * value_right < 0.0 { + // Different signs, therefore this is a valid bracketing interval + root_finder = AndersonBjorckRootFinder::new( + initial_left, initial_right, + value_left, value_right // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of + ); + + ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case + } + } + let mut iteration = 1; loop { - // Get starting guess, constrained at time_index - let mut p_test = p.clone(); - let cur_survival_prob = s[time_index]; - let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability - let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob); + if root_finder.precision() <= ci_precision { + // Desired precision has been reached + // We check this first so that if an initial guess is supplied, we can terminate immediately here if it is sufficiently good + break; + } - let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); + let ll_test = profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, ci_estimate); let lr_statistic = 2.0 * (ll_model - ll_test); root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95); ci_estimate = root_finder.next_guess(); - if root_finder.precision() <= ci_precision { - // Desired precision has been reached + if (lr_statistic - CHI2_1DF_95).abs() <= ll_tolerance { break; } @@ -658,34 +725,53 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress } let ci_lower = ci_estimate; + let ci_lower_bounds = root_finder.bounds(); + // ------------------------------ // Compute upper confidence limit + root_finder = AndersonBjorckRootFinder::new( - s[time_index], 1.0, - -CHI2_1DF_95, f64::NAN // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of + 0.0, s[time_index], + f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of ); - ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1]; - if ci_estimate > 1.0 { + ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1]; + if ci_estimate < 0.0 { ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case } + // Use initial guess if available + if let Some((_, (initial_left, initial_right))) = initial_guess { + let value_left = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_left)) - CHI2_1DF_95; + let value_right = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_right)) - CHI2_1DF_95; + + if value_left * value_right < 0.0 { + // Different signs, therefore this is a valid bracketing interval + root_finder = AndersonBjorckRootFinder::new( + initial_left, initial_right, + value_left, value_right // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of + ); + + // TODO: Terminate if reached precision already + + ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case + } + } + let mut iteration = 1; loop { - // Get starting guess, constrained at time_index - let mut p_test = p.clone(); - let cur_survival_prob = s[time_index]; - let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability - let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob); + if root_finder.precision() <= ci_precision { + // Desired precision has been reached + break; + } - let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); + let ll_test = profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, ci_estimate); let lr_statistic = 2.0 * (ll_model - ll_test); root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95); ci_estimate = root_finder.next_guess(); - if root_finder.precision() <= ci_precision { - // Desired precision has been reached + if (lr_statistic - CHI2_1DF_95).abs() <= ll_tolerance { break; } @@ -696,8 +782,21 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress } let ci_upper = ci_estimate; + let ci_upper_bounds = root_finder.bounds(); - return (ci_lower, ci_upper); + return (ci_lower, ci_lower_bounds, ci_upper, ci_upper_bounds); +} + +fn profile_likelihood_survival_prob(data: &TurnbullData, progress_bar: &ProgressBar, max_iterations: u32, ll_tolerance: f64, p: &Vec, s: &Vec, time_index: usize, survival_prob: f64) -> f64 { + // Get starting guess, constrained at time_index + let mut p_test = p.clone(); + let cur_survival_prob = s[time_index]; + let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - survival_prob) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability + let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= survival_prob / cur_survival_prob); + + let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: survival_prob })); + + return ll_test; } #[derive(Serialize, Deserialize)]