diff --git a/src/turnbull.rs b/src/turnbull.rs index 71aa40c..cf7b7bf 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -56,6 +56,10 @@ pub struct TurnbullArgs { /// Threshold for dropping failure probability in --se-method oim-drop-zeros #[arg(long, default_value="0.0001")] zero_tolerance: f64, + + /// Desired precision of confidence limits in --se-method likelihood-ratio + #[arg(long, default_value="0.01")] + ci_precision: f64, } #[derive(ValueEnum, Clone)] @@ -78,7 +82,7 @@ pub fn main(args: TurnbullArgs) { // Fit regression let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); - let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance); + let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance, args.ci_precision); // Display output match args.output { @@ -184,7 +188,7 @@ struct Constraint { survival_prob: f64, } -pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { +pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64, ci_precision: f64) -> TurnbullResult { // ---------------------- // Prepare for regression @@ -245,7 +249,7 @@ pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_i progress_bar.println("Computing confidence intervals by likelihood ratio test"); let confidence_intervals = (1..data.num_intervals()).into_par_iter() - .map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, &p, ll, &s, &oim_se, j)) + .map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, ci_precision, &p, ll, &s, &oim_se, j)) .progress_with(progress_bar.clone()) .collect(); @@ -607,7 +611,7 @@ fn compute_hessian(data: &TurnbullData, p: &Vec) -> DMatrix { return hessian; } -fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize) -> (f64, f64) { +fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, ci_precision: f64, p: &Vec, ll_model: f64, s: &Vec, oim_se: &Vec, time_index: usize) -> (f64, f64) { // Compute lower confidence limit let mut ci_bound_lower = 0.0; let mut ci_bound_upper = s[time_index]; @@ -627,10 +631,7 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let lr_statistic = 2.0 * (ll_model - ll_test); - if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { - // Converged! - break; - } else if lr_statistic > CHI2_1DF_95 { + if lr_statistic > CHI2_1DF_95 { // CI is too wide ci_bound_lower = ci_estimate; } else { @@ -638,11 +639,13 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress ci_bound_upper = ci_estimate; } - // FIXME: Sometimes this does not converge within max_iterations - investigate why this is - // If it is not an issue with the algorithm, then we should also terminate if width of interval is narrower than a specified tolerance - ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + if ci_bound_upper - ci_bound_lower <= ci_precision { + // Desired precision has been reached + break; + } + iteration += 1; if iteration > max_iterations { panic!("Exceeded --max-iterations"); @@ -670,10 +673,7 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); let lr_statistic = 2.0 * (ll_model - ll_test); - if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { - // Converged! - break; - } else if lr_statistic > CHI2_1DF_95 { + if lr_statistic > CHI2_1DF_95 { // CI is too wide ci_bound_upper = ci_estimate; } else { @@ -683,6 +683,11 @@ fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: Progress ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + if ci_bound_upper - ci_bound_lower <= ci_precision { + // Desired precision has been reached + break; + } + iteration += 1; if iteration > max_iterations { panic!("Exceeded --max-iterations");