turnbull: Refactor root-finding code

This commit is contained in:
RunasSudo 2023-12-26 18:29:24 +11:00
parent 307aff6f14
commit 760e3bbb0e
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
3 changed files with 81 additions and 26 deletions

View File

@ -3,4 +3,5 @@ pub mod turnbull;
mod csv; mod csv;
mod pava; mod pava;
mod root_finding;
mod term; mod term;

61
src/root_finding.rs Normal file
View File

@ -0,0 +1,61 @@
// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
pub struct BisectionRootFinder {
bound_lower: f64,
bound_upper: f64,
value_lower: f64,
value_upper: f64
}
impl BisectionRootFinder {
pub fn new(bound_lower: f64, bound_upper: f64, value_lower: f64, value_upper: f64,) -> BisectionRootFinder {
return BisectionRootFinder {
bound_lower: bound_lower,
bound_upper: bound_upper,
value_lower: value_lower,
value_upper: value_upper
}
}
pub fn update(&mut self, guess: f64, value: f64) {
if value > 0.0 {
if self.value_lower > 0.0 || self.value_upper < 0.0 {
self.bound_lower = guess;
self.value_lower = value;
} else {
self.bound_upper = guess;
self.value_upper = value;
}
} else {
if self.value_lower < 0.0 || self.value_upper > 0.0 {
self.bound_lower = guess;
self.value_lower = value;
} else {
self.bound_upper = guess;
self.value_upper = value;
}
}
}
pub fn next_guess(&self) -> f64 {
return (self.bound_lower + self.bound_upper) / 2.0;
}
pub fn precision(&self) -> f64 {
return self.bound_upper - self.bound_lower;
}
}

View File

@ -29,6 +29,7 @@ use serde::{Serialize, Deserialize};
use crate::csv::read_csv; use crate::csv::read_csv;
use crate::pava::monotonic_regression_pava; use crate::pava::monotonic_regression_pava;
use crate::root_finding::BisectionRootFinder;
use crate::term::UnconditionalTermLike; use crate::term::UnconditionalTermLike;
#[derive(Args)] #[derive(Args)]
@ -621,11 +622,14 @@ fn compute_hessian(data: &TurnbullData, p: &Vec<f64>) -> DMatrix<f64> {
fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) { fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) {
// Compute lower confidence limit // Compute lower confidence limit
let mut ci_bound_lower = 0.0; let mut root_finder = BisectionRootFinder::new(
let mut ci_bound_upper = s[time_index]; 0.0, s[time_index],
f64::NAN, -CHI2_1DF_95 // Value of (lr_statistic - CHI2_1DF_95), which we are seeking the roots of
);
let mut ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1]; let mut ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1];
if ci_estimate < 0.0 { if ci_estimate < 0.0 {
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; ci_estimate = root_finder.next_guess(); // Returns interval midpoint in this case
} }
let mut iteration = 1; let mut iteration = 1;
@ -639,17 +643,10 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
let lr_statistic = 2.0 * (ll_model - ll_test); let lr_statistic = 2.0 * (ll_model - ll_test);
if lr_statistic > CHI2_1DF_95 { root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
// CI is too wide ci_estimate = root_finder.next_guess();
ci_bound_lower = ci_estimate;
} else {
// CI is too narrow
ci_bound_upper = ci_estimate;
}
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; if root_finder.precision() <= ci_precision {
if ci_bound_upper - ci_bound_lower <= ci_precision {
// Desired precision has been reached // Desired precision has been reached
break; break;
} }
@ -663,11 +660,14 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
let ci_lower = ci_estimate; let ci_lower = ci_estimate;
// Compute upper confidence limit // Compute upper confidence limit
ci_bound_lower = s[time_index]; root_finder = BisectionRootFinder::new(
ci_bound_upper = 1.0; s[time_index], 1.0,
-CHI2_1DF_95, f64::NAN
);
ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1]; ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1];
if ci_estimate > 1.0 { if ci_estimate > 1.0 {
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; ci_estimate = root_finder.next_guess();
} }
let mut iteration = 1; let mut iteration = 1;
@ -681,17 +681,10 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
let lr_statistic = 2.0 * (ll_model - ll_test); let lr_statistic = 2.0 * (ll_model - ll_test);
if lr_statistic > CHI2_1DF_95 { root_finder.update(ci_estimate, lr_statistic - CHI2_1DF_95);
// CI is too wide ci_estimate = root_finder.next_guess();
ci_bound_upper = ci_estimate;
} else {
// CI is too narrow
ci_bound_lower = ci_estimate;
}
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; if root_finder.precision() <= ci_precision {
if ci_bound_upper - ci_bound_lower <= ci_precision {
// Desired precision has been reached // Desired precision has been reached
break; break;
} }