From 759c2c47781680d765da9467613adea66601b3cb Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Tue, 26 Dec 2023 19:18:38 +1100 Subject: [PATCH] turnbull: Use Illinois method rather than interval bisection for likelihood-ratio confidence intervals 27% speedup NB: Regula falsi alone without Illinois adjustment was slower than interval bisection --- src/root_finding.rs | 45 ++++++++++++++++++++++++++++++++++++--------- src/turnbull.rs | 10 +++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/src/root_finding.rs b/src/root_finding.rs index 99c10d4..d712b14 100644 --- a/src/root_finding.rs +++ b/src/root_finding.rs @@ -14,48 +14,75 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -pub struct BisectionRootFinder { +pub struct IllinoisRootFinder { bound_lower: f64, bound_upper: f64, value_lower: f64, - value_upper: f64 + value_upper: f64, + last_sign: f64 // Sign of the function at last evaluation (1.0 or -1.0) or 0.0 if first iteration } -impl BisectionRootFinder { - pub fn new(bound_lower: f64, bound_upper: f64, value_lower: f64, value_upper: f64,) -> BisectionRootFinder { - return BisectionRootFinder { +impl IllinoisRootFinder { + pub fn new(bound_lower: f64, bound_upper: f64, value_lower: f64, value_upper: f64) -> IllinoisRootFinder { + return IllinoisRootFinder { bound_lower: bound_lower, bound_upper: bound_upper, value_lower: value_lower, - value_upper: value_upper + value_upper: value_upper, + last_sign: 0.0 } } - pub fn update(&mut self, guess: f64, value: f64) { + pub fn update(&mut self, guess: f64, mut value: f64) { if value > 0.0 { if self.value_lower > 0.0 || self.value_upper < 0.0 { self.bound_lower = guess; self.value_lower = value; + + if self.last_sign == 1.0 { + // Illinois adjustment: Halve the y-value of the retained end point (the other end point) when the new y-value has the same sign as the previous one + self.value_upper *= 0.5; + } } else { self.bound_upper = guess; self.value_upper = value; + + if self.last_sign == 1.0 { + self.value_lower *= 0.5; + } } + self.last_sign = 1.0; } else { if self.value_lower < 0.0 || self.value_upper > 0.0 { self.bound_lower = guess; self.value_lower = value; + + if self.last_sign == -1.0 { + self.value_upper *= 0.5; + } } else { self.bound_upper = guess; self.value_upper = value; + + if self.last_sign == -1.0 { + self.value_lower *= 0.5; + } } + self.last_sign = -1.0; } } pub fn next_guess(&self) -> f64 { - return (self.bound_lower + self.bound_upper) / 2.0; + if self.value_lower.is_nan() || self.value_upper.is_nan() { + // Fall back to interval bisection + return (self.bound_lower + self.bound_upper) / 2.0; + } else { + // Regula falsi + return (self.bound_lower * self.value_upper - self.bound_upper * self.value_lower) / (self.value_upper - self.value_lower); + } } pub fn precision(&self) -> f64 { - return self.bound_upper - self.bound_lower; + return (self.bound_upper - self.bound_lower).abs(); } } diff --git a/src/turnbull.rs b/src/turnbull.rs index 3b71566..3c32b61 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -29,7 +29,7 @@ use serde::{Serialize, Deserialize}; use crate::csv::read_csv; use crate::pava::monotonic_regression_pava; -use crate::root_finding::BisectionRootFinder; +use crate::root_finding::IllinoisRootFinder; use crate::term::UnconditionalTermLike; #[derive(Args)] @@ -622,7 +622,7 @@ fn compute_hessian(data: &TurnbullData, p: &Vec) -> DMatrix { fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize) -> (f64, f64) { // Compute lower confidence limit - let mut root_finder = BisectionRootFinder::new( + let mut root_finder = IllinoisRootFinder::new( 0.0, s[time_index], f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of ); @@ -660,14 +660,14 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let ci_lower = ci_estimate; // Compute upper confidence limit - root_finder = BisectionRootFinder::new( + root_finder = IllinoisRootFinder::new( s[time_index], 1.0, - -CHI2_1DF_95, f64::NAN + -CHI2_1DF_95, f64::NAN // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of ); ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1]; if ci_estimate > 1.0 { - ci_estimate = root_finder.next_guess(); + ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case } let mut iteration = 1;