diff --git a/src/lib.rs b/src/lib.rs index 11e3925..376f503 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,4 +3,5 @@ pub mod turnbull; mod csv; mod pava; +mod root_finding; mod term; diff --git a/src/root_finding.rs b/src/root_finding.rs new file mode 100644 index 0000000..99c10d4 --- /dev/null +++ b/src/root_finding.rs @@ -0,0 +1,61 @@ +// hpstat: High-performance statistics implementations +// Copyright © 2023 Lee Yingtong Li (RunasSudo) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +pub struct BisectionRootFinder { + bound_lower: f64, + bound_upper: f64, + value_lower: f64, + value_upper: f64 +} + +impl BisectionRootFinder { + pub fn new(bound_lower: f64, bound_upper: f64, value_lower: f64, value_upper: f64,) -> BisectionRootFinder { + return BisectionRootFinder { + bound_lower: bound_lower, + bound_upper: bound_upper, + value_lower: value_lower, + value_upper: value_upper + } + } + + pub fn update(&mut self, guess: f64, value: f64) { + if value > 0.0 { + if self.value_lower > 0.0 || self.value_upper < 0.0 { + self.bound_lower = guess; + self.value_lower = value; + } else { + self.bound_upper = guess; + self.value_upper = value; + } + } else { + if self.value_lower < 0.0 || self.value_upper > 0.0 { + self.bound_lower = guess; + self.value_lower = value; + } else { + self.bound_upper = guess; + self.value_upper = value; + } + } + } + + pub fn next_guess(&self) -> f64 { + return (self.bound_lower + self.bound_upper) / 2.0; + } + + pub fn precision(&self) -> f64 { + return self.bound_upper - self.bound_lower; + } +} diff --git a/src/turnbull.rs b/src/turnbull.rs index 3c956e9..3b71566 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -29,6 +29,7 @@ use serde::{Serialize, Deserialize}; use crate::csv::read_csv; use crate::pava::monotonic_regression_pava; +use crate::root_finding::BisectionRootFinder; use crate::term::UnconditionalTermLike; #[derive(Args)] @@ -621,11 +622,14 @@ fn compute_hessian(data: &TurnbullData, p: &Vec) -> DMatrix { fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize) -> (f64, f64) { // Compute lower confidence limit - let mut ci_bound_lower = 0.0; - let mut ci_bound_upper = s[time_index]; + let mut root_finder = BisectionRootFinder::new( + 0.0, s[time_index], + f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of + ); + let mut ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1]; if ci_estimate < 0.0 { - ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case } let mut iteration = 1; @@ -639,17 +643,10 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let lr_statistic = 2.0 * (ll_model - ll_test); - if lr_statistic > CHI2_1DF_95 { - // CI is too wide - ci_bound_lower = ci_estimate; - } else { - // CI is too narrow - ci_bound_upper = ci_estimate; - } + root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95); + ci_estimate = root_finder.next_guess(); - ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; - - if ci_bound_upper - ci_bound_lower <= ci_precision { + if root_finder.precision() <= ci_precision { // Desired precision has been reached break; } @@ -663,11 +660,14 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let ci_lower = ci_estimate; // Compute upper confidence limit - ci_bound_lower = s[time_index]; - ci_bound_upper = 1.0; + root_finder = BisectionRootFinder::new( + s[time_index], 1.0, + -CHI2_1DF_95, f64::NAN + ); + ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1]; if ci_estimate > 1.0 { - ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + ci_estimate = root_finder.next_guess(); } let mut iteration = 1; @@ -681,17 +681,10 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let lr_statistic = 2.0 * (ll_model - ll_test); - if lr_statistic > CHI2_1DF_95 { - // CI is too wide - ci_bound_upper = ci_estimate; - } else { - // CI is too narrow - ci_bound_lower = ci_estimate; - } + root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95); + ci_estimate = root_finder.next_guess(); - ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; - - if ci_bound_upper - ci_bound_lower <= ci_precision { + if root_finder.precision() <= ci_precision { // Desired precision has been reached break; }