diff --git a/src/turnbull.rs b/src/turnbull.rs index 69017f9..6874a79 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -19,7 +19,7 @@ const CHI2_1DF_95: f64 = 3.8414588; use std::fs::File; use std::io::{self, BufReader}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use clap::{Args, ValueEnum}; use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; @@ -250,32 +250,27 @@ pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Computing confidence intervals by likelihood ratio test"); - // (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper)) - // TODO: Refactor this (unsafe code?) - each thread reads/writes only one value so there is no need for locking - let ci_with_bounds = Arc::new( - Vec::from_iter((1..data.num_intervals()).map(|_| RwLock::new((f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN))))) - ); - // First do intervals with nonzero failure probability - (1..data.num_intervals()).into_par_iter() - .for_each(|j| { + let ci_with_bounds: Vec<(f64, (f64, f64), f64, (f64, f64))> = (1..data.num_intervals()).into_par_iter() + .map(|j| { if p[j - 1] <= 0.0001 { // To see if the survival probability at the j-th time index is the same as (j-1)-th, check the (j-1)-th failure probability - return; + return (f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN)); } let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, None); - let mut r = ci_with_bounds[j - 1].write().unwrap(); - *r = ci; progress_bar.inc(1); - }); + return ci; // (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper)) + }) + .collect(); + + let ci_with_bounds = Arc::new(ci_with_bounds); // Fill initial guesses for intervals with zero failure probability let mut initial_guesses = Vec::with_capacity(data.num_intervals() - 1); for j in 1..data.num_intervals() { if p[j - 1] > 0.0001 { - let r = ci_with_bounds[j - 1].read().unwrap(); - initial_guesses.push(Some((r.1, r.3))); + initial_guesses.push(Some((ci_with_bounds[j - 1].1, ci_with_bounds[j - 1].3))); } else if j >= 2 { initial_guesses.push(initial_guesses[j - 2]); // Carry forward final bounds from last time point } else { @@ -284,24 +279,21 @@ pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_i } // Now do intervals with zero failure probability - (1..data.num_intervals()).into_par_iter() - .for_each(|j| { + let ci_with_bounds: Vec<(f64, (f64, f64), f64, (f64, f64))> = (1..data.num_intervals()).into_par_iter() + .map(|j| { if p[j - 1] > 0.0001 { - return; + return ci_with_bounds[j - 1]; } let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, initial_guesses[j - 1]); - let mut r = ci_with_bounds[j - 1].write().unwrap(); - *r = ci; progress_bar.inc(1); - }); + return ci; + }) + .collect(); let confidence_intervals = ci_with_bounds.iter() - .map(|x| { - let r = x.read().unwrap(); - (r.0, r.2) - }) + .map(|x| (x.0, x.2)) .collect(); survival_prob_ci = Some(confidence_intervals);