Add --ll_tolerance parameter for intcox

This commit is contained in:
RunasSudo 2023-04-23 18:36:28 +10:00
parent d0d92f2a78
commit 1497e2d5cb
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -44,7 +44,11 @@ pub struct IntCoxArgs {
/// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance /// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance
#[arg(long, default_value="0.001")] #[arg(long, default_value="0.001")]
tolerance: f64, param_tolerance: f64,
/// Terminate E-M algorithm when the absolute change in log-likelihood is less than this tolerance
#[arg(long)]
ll_tolerance: Option<f64>,
/// Estimate baseline hazard function using Turnbull innermost intervals /// Estimate baseline hazard function using Turnbull innermost intervals
#[arg(long)] #[arg(long)]
@ -63,7 +67,7 @@ pub fn main(args: IntCoxArgs) {
// Fit regression // Fit regression
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.tolerance, args.reduced, progress_bar); let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.param_tolerance, args.ll_tolerance, args.reduced, progress_bar);
// Display output // Display output
match args.output { match args.output {
@ -172,7 +176,7 @@ impl IntervalCensoredCoxData {
} }
} }
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult { pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, param_tolerance: f64, ll_tolerance: Option<f64>, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
// ---------------------- // ----------------------
// Prepare for regression // Prepare for regression
@ -256,6 +260,7 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
progress_bar.println("Running E-M algorithm to fit interval-censored Cox model"); progress_bar.println("Running E-M algorithm to fit interval-censored Cox model");
let mut iteration: u32 = 0; let mut iteration: u32 = 0;
let mut ll_model: f64 = 0.0;
loop { loop {
// Pre-compute exp(β^T * Z_ik) // Pre-compute exp(β^T * Z_ik)
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); }); let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
@ -266,17 +271,31 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
// Do M-step // Do M-step
let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight); let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight);
// Check for convergence // Check for convergence (param_tolerance)
let (coef_change, converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, tolerance); let (coef_change, mut converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, param_tolerance);
beta = new_beta; beta = new_beta;
lambda = new_lambda; lambda = new_lambda;
// Update progress bar // Estimate progress bar according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations
// Estimate progress according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations let progress1 = ((-coef_change.log10()).max(0.0) / -param_tolerance.log10() * u64::MAX as f64) as u64;
let progress1 = ((-coef_change.log10()).max(0.0) / -tolerance.log10() * u64::MAX as f64) as u64;
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64; let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
if let Some(ll_tolerance_amount) = ll_tolerance {
// Check for convergence (ll_tolerance)
let new_ll = log_likelihood_obs(&data, &beta, &lambda).sum();
let ll_change = new_ll - ll_model;
converged = converged && (ll_change < ll_tolerance_amount);
ll_model = new_ll;
// Update progress bar
let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance_amount.log10() * u64::MAX as f64) as u64;
progress_bar.set_position(progress_bar.position().max(progress1.min(progress3)).max(progress2));
progress_bar.set_message(format!("Iteration {} (Δparams = {:.6}, ΔLL = {:.4})", iteration + 1, coef_change, ll_change));
} else {
// Update progress bar
progress_bar.set_position(progress_bar.position().max(progress1).max(progress2)); progress_bar.set_position(progress_bar.position().max(progress1).max(progress2));
progress_bar.set_message(format!("Iter {} (delta = {:.6})", iteration + 1, coef_change)); progress_bar.set_message(format!("Iteration {} (Δparams = {:.6})", iteration + 1, coef_change));
}
if converged { if converged {
progress_bar.finish(); progress_bar.finish();
@ -290,7 +309,11 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
} }
// Compute log-likelihood // Compute log-likelihood
let ll_model = log_likelihood_obs(&data, &beta, &lambda).sum(); if let Some(_) = ll_tolerance {
// Already computed above
} else {
ll_model = log_likelihood_obs(&data, &beta, &lambda).sum();
}
// Unstandardise betas // Unstandardise betas
let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs()); let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
@ -313,15 +336,15 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
// pll_toggle_zero = log-likelihoods for each observation at final beta // pll_toggle_zero = log-likelihoods for each observation at final beta
// pll_toggle_one = log-likelihoods for each observation at toggled beta // pll_toggle_one = log-likelihoods for each observation at toggled beta
let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, tolerance).sum(); let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, param_tolerance).sum();
let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, tolerance); let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, param_tolerance);
progress_bar.inc(1); progress_bar.inc(1);
let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| { let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
let mut pll_beta = beta.clone(); let mut pll_beta = beta.clone();
pll_beta[j] += h; pll_beta[j] += h;
profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, tolerance) profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, param_tolerance)
}) })
.progress_with(progress_bar.clone()) .progress_with(progress_bar.clone())
.collect(); .collect();