turnbull: Use smarter initial guesses for likelihood-ratio confidence intervals
When the survival probability at a point is the same as at the previous point, the confidence interval should be similar. So re-use the final bracketing interval as the initial guess to save time in the root-finding. 150% speedup!
This commit is contained in:
parent
b569956de7
commit
204571d6cb
@ -102,6 +102,10 @@ impl AndersonBjorckRootFinder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the root finder's current bounds as a `(bound_lower, bound_upper)` pair.
pub fn bounds(&self) -> (f64, f64) {
|
||||||
|
return (self.bound_lower, self.bound_upper);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn precision(&self) -> f64 {
|
pub fn precision(&self) -> f64 {
|
||||||
return (self.bound_upper - self.bound_lower).abs();
|
return (self.bound_upper - self.bound_lower).abs();
|
||||||
}
|
}
|
||||||
|
151
src/turnbull.rs
151
src/turnbull.rs
@ -19,9 +19,10 @@ const CHI2_1DF_95: f64 = 3.8414588;
|
|||||||
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufReader};
|
use std::io::{self, BufReader};
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use nalgebra::{DMatrix, DVector, Matrix2xX};
|
use nalgebra::{DMatrix, DVector, Matrix2xX};
|
||||||
use prettytable::{Table, format, row};
|
use prettytable::{Table, format, row};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
@ -249,9 +250,58 @@ pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_i
|
|||||||
progress_bar.reset();
|
progress_bar.reset();
|
||||||
progress_bar.println("Computing confidence intervals by likelihood ratio test");
|
progress_bar.println("Computing confidence intervals by likelihood ratio test");
|
||||||
|
|
||||||
let confidence_intervals = (1..data.num_intervals()).into_par_iter()
|
// (CI left, (CI left lower, CI left upper), CI right, (CI right lower, CI right upper))
|
||||||
.map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j))
|
// TODO: Refactor this (unsafe code?) - each thread reads/writes only one value so there is no need for locking
|
||||||
.progress_with(progress_bar.clone())
|
let ci_with_bounds = Arc::new(
|
||||||
|
Vec::from_iter((1..data.num_intervals()).map(|_| RwLock::new((f64::NAN, (f64::NAN, f64::NAN), f64::NAN, (f64::NAN, f64::NAN)))))
|
||||||
|
);
|
||||||
|
|
||||||
|
// First do intervals with nonzero failure probability
|
||||||
|
(1..data.num_intervals()).into_par_iter()
|
||||||
|
.for_each(|j| {
|
||||||
|
if p[j - 1] <= 0.0001 { // To see if the survival probability at the j-th time index is the same as (j-1)-th, check the (j-1)-th failure probability
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, None);
|
||||||
|
let mut r = ci_with_bounds[j - 1].write().unwrap();
|
||||||
|
*r = ci;
|
||||||
|
|
||||||
|
progress_bar.inc(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Fill initial guesses for intervals with zero failure probability
|
||||||
|
let mut initial_guesses = Vec::with_capacity(data.num_intervals() - 1);
|
||||||
|
for j in 1..data.num_intervals() {
|
||||||
|
if p[j - 1] > 0.0001 {
|
||||||
|
let r = ci_with_bounds[j - 1].read().unwrap();
|
||||||
|
initial_guesses.push(Some((r.1, r.3)));
|
||||||
|
} else if j >= 2 {
|
||||||
|
initial_guesses.push(initial_guesses[j - 2]); // Carry forward final bounds from last time point
|
||||||
|
} else {
|
||||||
|
initial_guesses.push(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now do intervals with zero failure probability
|
||||||
|
(1..data.num_intervals()).into_par_iter()
|
||||||
|
.for_each(|j| {
|
||||||
|
if p[j - 1] > 0.0001 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ci = survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j, initial_guesses[j - 1]);
|
||||||
|
let mut r = ci_with_bounds[j - 1].write().unwrap();
|
||||||
|
*r = ci;
|
||||||
|
|
||||||
|
progress_bar.inc(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
let confidence_intervals = ci_with_bounds.iter()
|
||||||
|
.map(|x| {
|
||||||
|
let r = x.read().unwrap();
|
||||||
|
(r.0, r.2)
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
survival_prob_ci = Some(confidence_intervals);
|
survival_prob_ci = Some(confidence_intervals);
|
||||||
@ -620,8 +670,10 @@ fn compute_hessian(data: &TurnbullData, p: &Vec<f64>) -> DMatrix<f64> {
|
|||||||
return hessian;
|
return hessian;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) {
|
fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize, initial_guess: Option<((f64, f64), (f64, f64))>) -> (f64, (f64, f64), f64, (f64, f64)) {
|
||||||
|
// ------------------------------
|
||||||
// Compute lower confidence limit
|
// Compute lower confidence limit
|
||||||
|
|
||||||
let mut root_finder = AndersonBjorckRootFinder::new(
|
let mut root_finder = AndersonBjorckRootFinder::new(
|
||||||
0.0, s[time_index],
|
0.0, s[time_index],
|
||||||
f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
||||||
@ -632,22 +684,37 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
|
|||||||
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use initial guess if available
|
||||||
|
if let Some(((initial_left, initial_right), _)) = initial_guess {
|
||||||
|
let value_left = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_left)) - CHI2_1DF_95;
|
||||||
|
let value_right = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_right)) - CHI2_1DF_95;
|
||||||
|
|
||||||
|
if value_left * value_right < 0.0 {
|
||||||
|
// Different signs, therefore this is a valid bracketing interval
|
||||||
|
root_finder = AndersonBjorckRootFinder::new(
|
||||||
|
initial_left, initial_right,
|
||||||
|
value_left, value_right // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
||||||
|
);
|
||||||
|
|
||||||
|
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut iteration = 1;
|
let mut iteration = 1;
|
||||||
loop {
|
loop {
|
||||||
// Get starting guess, constrained at time_index
|
if root_finder.precision() <= ci_precision {
|
||||||
let mut p_test = p.clone();
|
// Desired precision has been reached
|
||||||
let cur_survival_prob = s[time_index];
|
// We check this first so that if an initial guess is supplied, we can terminate immediately here if it is sufficiently good
|
||||||
let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability
|
break;
|
||||||
let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob);
|
}
|
||||||
|
|
||||||
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
|
let ll_test = profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, ci_estimate);
|
||||||
let lr_statistic = 2.0 * (ll_model - ll_test);
|
let lr_statistic = 2.0 * (ll_model - ll_test);
|
||||||
|
|
||||||
root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
|
root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
|
||||||
ci_estimate = root_finder.next_guess();
|
ci_estimate = root_finder.next_guess();
|
||||||
|
|
||||||
if root_finder.precision() <= ci_precision {
|
if (lr_statistic - CHI2_1DF_95).abs() <= ll_tolerance {
|
||||||
// Desired precision has been reached
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -658,34 +725,53 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
|
|||||||
}
|
}
|
||||||
|
|
||||||
let ci_lower = ci_estimate;
|
let ci_lower = ci_estimate;
|
||||||
|
let ci_lower_bounds = root_finder.bounds();
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
// Compute upper confidence limit
|
// Compute upper confidence limit
|
||||||
|
|
||||||
root_finder = AndersonBjorckRootFinder::new(
|
root_finder = AndersonBjorckRootFinder::new(
|
||||||
s[time_index], 1.0,
|
0.0, s[time_index],
|
||||||
-CHI2_1DF_95, f64::NAN // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
||||||
);
|
);
|
||||||
|
|
||||||
ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1];
|
ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1];
|
||||||
if ci_estimate > 1.0 {
|
if ci_estimate < 0.0 {
|
||||||
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use initial guess if available
|
||||||
|
if let Some((_, (initial_left, initial_right))) = initial_guess {
|
||||||
|
let value_left = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_left)) - CHI2_1DF_95;
|
||||||
|
let value_right = 2.0 * (ll_model - profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, initial_right)) - CHI2_1DF_95;
|
||||||
|
|
||||||
|
if value_left * value_right < 0.0 {
|
||||||
|
// Different signs, therefore this is a valid bracketing interval
|
||||||
|
root_finder = AndersonBjorckRootFinder::new(
|
||||||
|
initial_left, initial_right,
|
||||||
|
value_left, value_right // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: Terminate if reached precision already
|
||||||
|
|
||||||
|
ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut iteration = 1;
|
let mut iteration = 1;
|
||||||
loop {
|
loop {
|
||||||
// Get starting guess, constrained at time_index
|
if root_finder.precision() <= ci_precision {
|
||||||
let mut p_test = p.clone();
|
// Desired precision has been reached
|
||||||
let cur_survival_prob = s[time_index];
|
break;
|
||||||
let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability
|
}
|
||||||
let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob);
|
|
||||||
|
|
||||||
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
|
let ll_test = profile_likelihood_survival_prob(data, &progress_bar, max_iterations, ll_tolerance, p, s, time_index, ci_estimate);
|
||||||
let lr_statistic = 2.0 * (ll_model - ll_test);
|
let lr_statistic = 2.0 * (ll_model - ll_test);
|
||||||
|
|
||||||
root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
|
root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
|
||||||
ci_estimate = root_finder.next_guess();
|
ci_estimate = root_finder.next_guess();
|
||||||
|
|
||||||
if root_finder.precision() <= ci_precision {
|
if (lr_statistic - CHI2_1DF_95).abs() <= ll_tolerance {
|
||||||
// Desired precision has been reached
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -696,8 +782,21 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
|
|||||||
}
|
}
|
||||||
|
|
||||||
let ci_upper = ci_estimate;
|
let ci_upper = ci_estimate;
|
||||||
|
let ci_upper_bounds = root_finder.bounds();
|
||||||
|
|
||||||
return (ci_lower, ci_upper);
|
return (ci_lower, ci_lower_bounds, ci_upper, ci_upper_bounds);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refits the model with the survival probability at `time_index` constrained to `survival_prob`, and returns the log-likelihood (`ll_test`) reported by `fit_turnbull_estimator`.
///
/// The starting point for the constrained fit is a copy of `p`, rescaled on either side of `time_index` so that it is consistent with the requested `survival_prob`. The fitted probabilities returned by `fit_turnbull_estimator` are discarded.
// NOTE(review): the rescaling divides by `1.0 - s[time_index]` and by `s[time_index]` - presumably callers only pass time indices whose survival probability is strictly between 0 and 1; confirm.
fn profile_likelihood_survival_prob(data: &TurnbullData, progress_bar: &ProgressBar, max_iterations: u32, ll_tolerance: f64, p: &Vec<f64>, s: &Vec<f64>, time_index: usize, survival_prob: f64) -> f64 {
|
||||||
|
// Get starting guess, constrained at time_index
|
||||||
|
let mut p_test = p.clone();
|
||||||
|
let cur_survival_prob = s[time_index];
|
||||||
|
let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - survival_prob) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability
|
||||||
|
// Entries from time_index onwards are scaled by desired survival probability over current survival probability
let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= survival_prob / cur_survival_prob);
|
||||||
|
|
||||||
|
// Refit under the constraint; only the log-likelihood of the constrained fit is needed
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: survival_prob }));
|
||||||
|
|
||||||
|
return ll_test;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
|
Loading…
Reference in New Issue
Block a user