turnbull: Clarify likelihood-ratio CI multithreading code
Remove unnecessary use of RwLock and use map-reduce instead
This commit is contained in:
parent
204571d6cb
commit
162e415e07
@ -19,7 +19,7 @@ const CHI2_1DF_95: f64 = 3.8414588;
|
|||||||
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufReader};
|
use std::io::{self, BufReader};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
@ -250,32 +250,27 @@ pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_i
|
|||||||
progress_bar.reset();
|
progress_bar.reset();
|
||||||
progress_bar.println("Computing confidence intervals by likelihood ratio test");
|
progress_bar.println("Computing confidence intervals by likelihood ratio test");
|
||||||
|
|
||||||
// (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper))
|
|
||||||
// TODO: Refactor this (unsafe code?) - each thread reads/writes only one value so there is no need for locking
|
|
||||||
let ci_with_bounds = Arc::new(
|
|
||||||
Vec::from_iter((1..data.num_intervals()).map(|_| RwLock::new((f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN)))))
|
|
||||||
);
|
|
||||||
|
|
||||||
// First do intervals with nonzero failure probability
|
// First do intervals with nonzero failure probability
|
||||||
(1..data.num_intervals()).into_par_iter()
|
let ci_with_bounds: Vec<(f64, (f64, f64), f64, (f64, f64))> = (1..data.num_intervals()).into_par_iter()
|
||||||
.for_each(|j| {
|
.map(|j| {
|
||||||
if p[j - 1] <= 0.0001 { // To see if the survival probability at the j-th time index is the same as (j-1)-th, check the (j-1)-th failure probability
|
if p[j - 1] <= 0.0001 { // To see if the survival probability at the j-th time index is the same as (j-1)-th, check the (j-1)-th failure probability
|
||||||
return;
|
return (f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN));
|
||||||
}
|
}
|
||||||
|
|
||||||
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, None);
|
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, None);
|
||||||
let mut r = ci_with_bounds[j - 1].write().unwrap();
|
|
||||||
*r = ci;
|
|
||||||
|
|
||||||
progress_bar.inc(1);
|
progress_bar.inc(1);
|
||||||
});
|
return ci; // (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let ci_with_bounds = Arc::new(ci_with_bounds);
|
||||||
|
|
||||||
// Fill initial guesses for intervals with zero failure probability
|
// Fill initial guesses for intervals with zero failure probability
|
||||||
let mut initial_guesses = Vec::with_capacity(data.num_intervals() - 1);
|
let mut initial_guesses = Vec::with_capacity(data.num_intervals() - 1);
|
||||||
for j in 1..data.num_intervals() {
|
for j in 1..data.num_intervals() {
|
||||||
if p[j - 1] > 0.0001 {
|
if p[j - 1] > 0.0001 {
|
||||||
let r = ci_with_bounds[j - 1].read().unwrap();
|
initial_guesses.push(Some((ci_with_bounds[j - 1].1, ci_with_bounds[j - 1].3)));
|
||||||
initial_guesses.push(Some((r.1, r.3)));
|
|
||||||
} else if j >= 2 {
|
} else if j >= 2 {
|
||||||
initial_guesses.push(initial_guesses[j - 2]); // Carry forward final bounds from last time point
|
initial_guesses.push(initial_guesses[j - 2]); // Carry forward final bounds from last time point
|
||||||
} else {
|
} else {
|
||||||
@ -284,24 +279,21 @@ pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_i
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now do intervals with zero failure probability
|
// Now do intervals with zero failure probability
|
||||||
(1..data.num_intervals()).into_par_iter()
|
let ci_with_bounds: Vec<(f64, (f64, f64), f64, (f64, f64))> = (1..data.num_intervals()).into_par_iter()
|
||||||
.for_each(|j| {
|
.map(|j| {
|
||||||
if p[j - 1] > 0.0001 {
|
if p[j - 1] > 0.0001 {
|
||||||
return;
|
return ci_with_bounds[j - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, initial_guesses[j - 1]);
|
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, initial_guesses[j - 1]);
|
||||||
let mut r = ci_with_bounds[j - 1].write().unwrap();
|
|
||||||
*r = ci;
|
|
||||||
|
|
||||||
progress_bar.inc(1);
|
progress_bar.inc(1);
|
||||||
});
|
return ci;
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
let confidence_intervals = ci_with_bounds.iter()
|
let confidence_intervals = ci_with_bounds.iter()
|
||||||
.map(|x| {
|
.map(|x| (x.0, x.2))
|
||||||
let r = x.read().unwrap();
|
|
||||||
(r.0, r.2)
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
survival_prob_ci = Some(confidence_intervals);
|
survival_prob_ci = Some(confidence_intervals);
|
||||||
|
Loading…
Reference in New Issue
Block a user