From 1497e2d5cb364a841a8175ee4fc686350d4642f1 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 23 Apr 2023 18:36:28 +1000 Subject: [PATCH] Add --ll_tolerance parameter for intcox --- src/intcox.rs | 51 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/src/intcox.rs b/src/intcox.rs index 7494925..09be28b 100644 --- a/src/intcox.rs +++ b/src/intcox.rs @@ -44,7 +44,11 @@ pub struct IntCoxArgs { /// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance #[arg(long, default_value="0.001")] - tolerance: f64, + param_tolerance: f64, + + /// Terminate E-M algorithm when the absolute change in log-likelihood is less than this tolerance + #[arg(long)] + ll_tolerance: Option, /// Estimate baseline hazard function using Turnbull innermost intervals #[arg(long)] @@ -63,7 +67,7 @@ pub fn main(args: IntCoxArgs) { // Fit regression let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); - let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.tolerance, args.reduced, progress_bar); + let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.param_tolerance, args.ll_tolerance, args.reduced, progress_bar); // Display output match args.output { @@ -172,7 +176,7 @@ impl IntervalCensoredCoxData { } } -pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatrix, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult { +pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatrix, max_iterations: u32, param_tolerance: f64, ll_tolerance: Option, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult { // ---------------------- // Prepare for regression @@ -256,6 +260,7 @@ pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatr progress_bar.println("Running E-M algorithm to fit interval-censored Cox model"); let mut iteration: u32 = 0; + let mut ll_model: f64 = 0.0; loop { // Pre-compute exp(β^T * Z_ik) let exp_beta_z: Matrix1xX = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); }); @@ -266,17 +271,31 @@ pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatr // Do M-step let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight); - // Check for convergence - let (coef_change, converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, tolerance); + // Check for convergence (param_tolerance) + let (coef_change, mut converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, param_tolerance); beta = new_beta; lambda = new_lambda; - // Update progress bar - // Estimate progress according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations - let progress1 = ((-coef_change.log10()).max(0.0) / -tolerance.log10() * u64::MAX as f64) as u64; + // Estimate progress bar according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations + let progress1 = ((-coef_change.log10()).max(0.0) / -param_tolerance.log10() * u64::MAX as f64) as u64; let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64; - progress_bar.set_position(progress_bar.position().max(progress1).max(progress2)); - progress_bar.set_message(format!("Iter {} (delta = {:.6})", iteration + 1, coef_change)); + + if let Some(ll_tolerance_amount) = ll_tolerance { + // Check for convergence (ll_tolerance) + let new_ll = log_likelihood_obs(&data, &beta, &lambda).sum(); + let ll_change = new_ll - ll_model; + converged = converged && (ll_change < ll_tolerance_amount); + ll_model = new_ll; + + // Update progress bar + let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance_amount.log10() * u64::MAX as f64) as u64; + progress_bar.set_position(progress_bar.position().max(progress1.min(progress3)).max(progress2)); + progress_bar.set_message(format!("Iteration {} (Δparams = {:.6}, ΔLL = {:.4})", iteration + 1, coef_change, ll_change)); + } else { + // Update progress bar + progress_bar.set_position(progress_bar.position().max(progress1).max(progress2)); + progress_bar.set_message(format!("Iteration {} (Δparams = {:.6})", iteration + 1, coef_change)); + } if converged { progress_bar.finish(); @@ -290,7 +309,11 @@ pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatr } // Compute log-likelihood - let ll_model = log_likelihood_obs(&data, &beta, &lambda).sum(); + if let Some(_) = ll_tolerance { + // Already computed above + } else { + ll_model = log_likelihood_obs(&data, &beta, &lambda).sum(); + } // Unstandardise betas let mut beta_unstandardised: DVector = DVector::zeros(data.num_covs()); @@ -313,15 +336,15 @@ pub fn fit_interval_censored_cox(data_times: DMatrix, mut data_indep: DMatr // pll_toggle_zero = log-likelihoods for each observation at final beta // pll_toggle_one = log-likelihoods for each observation at toggled beta - let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, tolerance).sum(); + let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, param_tolerance).sum(); - let pll_toggle_zero: DVector = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, tolerance); + let pll_toggle_zero: DVector = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, param_tolerance); progress_bar.inc(1); let pll_toggle_one: Vec> = (0..data.num_covs()).into_par_iter().map(|j| { let mut pll_beta = beta.clone(); pll_beta[j] += h; - profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, tolerance) + profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, param_tolerance) }) .progress_with(progress_bar.clone()) .collect();