turnbull: Terminate search for likelihood-ratio confidence intervals based on precision

81% speedup
Also resolves issue where confidence interval search would occasionally never terminate
This commit is contained in:
RunasSudo 2023-12-25 20:08:40 +11:00
parent c67965478d
commit c9a8b5b8a5
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 20 additions and 15 deletions

View File

@ -56,6 +56,10 @@ pub struct TurnbullArgs {
/// Threshold for dropping failure probability in --se-method oim-drop-zeros /// Threshold for dropping failure probability in --se-method oim-drop-zeros
#[arg(long, default_value="0.0001")] #[arg(long, default_value="0.0001")]
zero_tolerance: f64, zero_tolerance: f64,
/// Desired precision of confidence limits in --se-method likelihood-ratio
#[arg(long, default_value="0.01")]
ci_precision: f64,
} }
#[derive(ValueEnum, Clone)] #[derive(ValueEnum, Clone)]
@ -78,7 +82,7 @@ pub fn main(args: TurnbullArgs) {
// Fit regression // Fit regression
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance); let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance, args.ci_precision);
// Display output // Display output
match args.output { match args.output {
@ -184,7 +188,7 @@ struct Constraint {
survival_prob: f64, survival_prob: f64,
} }
pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64, ci_precision: f64) -> TurnbullResult {
// ---------------------- // ----------------------
// Prepare for regression // Prepare for regression
@ -245,7 +249,7 @@ pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_i
progress_bar.println("Computing confidence intervals by likelihood ratio test"); progress_bar.println("Computing confidence intervals by likelihood ratio test");
let confidence_intervals = (1..data.num_intervals()).into_par_iter() let confidence_intervals = (1..data.num_intervals()).into_par_iter()
.map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, &p, ll, &s, &oim_se, j)) .map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j))
.progress_with(progress_bar.clone()) .progress_with(progress_bar.clone())
.collect(); .collect();
@ -607,7 +611,7 @@ fn compute_hessian(data: &TurnbullData, p: &Vec<f64>) -> DMatrix<f64> {
return hessian; return hessian;
} }
fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) { fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) {
// Compute lower confidence limit // Compute lower confidence limit
let mut ci_bound_lower = 0.0; let mut ci_bound_lower = 0.0;
let mut ci_bound_upper = s[time_index]; let mut ci_bound_upper = s[time_index];
@ -627,10 +631,7 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
let lr_statistic = 2.0 * (ll_model - ll_test); let lr_statistic = 2.0 * (ll_model - ll_test);
if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { if lr_statistic > CHI2_1DF_95 {
// Converged!
break;
} else if lr_statistic > CHI2_1DF_95 {
// CI is too wide // CI is too wide
ci_bound_lower = ci_estimate; ci_bound_lower = ci_estimate;
} else { } else {
@ -638,11 +639,13 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
ci_bound_upper = ci_estimate; ci_bound_upper = ci_estimate;
} }
// FIXME: Sometimes this does not converge within max_iterations - investigate why this is
// If it is not an issue with the algorithm, then we should also terminate if width of interval is narrower than a specified tolerance
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0;
if ci_bound_upper - ci_bound_lower <= ci_precision {
// Desired precision has been reached
break;
}
iteration += 1; iteration += 1;
if iteration > max_iterations { if iteration > max_iterations {
panic!("Exceeded --max-iterations"); panic!("Exceeded --max-iterations");
@ -670,10 +673,7 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate }));
let lr_statistic = 2.0 * (ll_model - ll_test); let lr_statistic = 2.0 * (ll_model - ll_test);
if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { if lr_statistic > CHI2_1DF_95 {
// Converged!
break;
} else if lr_statistic > CHI2_1DF_95 {
// CI is too wide // CI is too wide
ci_bound_upper = ci_estimate; ci_bound_upper = ci_estimate;
} else { } else {
@ -683,6 +683,11 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress
ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0;
if ci_bound_upper - ci_bound_lower <= ci_precision {
// Desired precision has been reached
break;
}
iteration += 1; iteration += 1;
if iteration > max_iterations { if iteration > max_iterations {
panic!("Exceeded --max-iterations"); panic!("Exceeded --max-iterations");