Add --ll_tolerance parameter for intcox
This commit is contained in:
parent
d0d92f2a78
commit
1497e2d5cb
@ -44,7 +44,11 @@ pub struct IntCoxArgs {
|
|||||||
|
|
||||||
/// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance
|
/// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance
|
||||||
#[arg(long, default_value="0.001")]
|
#[arg(long, default_value="0.001")]
|
||||||
tolerance: f64,
|
param_tolerance: f64,
|
||||||
|
|
||||||
|
/// Terminate E-M algorithm when the absolute change in log-likelihood is less than this tolerance
|
||||||
|
#[arg(long)]
|
||||||
|
ll_tolerance: Option<f64>,
|
||||||
|
|
||||||
/// Estimate baseline hazard function using Turnbull innermost intervals
|
/// Estimate baseline hazard function using Turnbull innermost intervals
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -63,7 +67,7 @@ pub fn main(args: IntCoxArgs) {
|
|||||||
|
|
||||||
// Fit regression
|
// Fit regression
|
||||||
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
|
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
|
||||||
let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.tolerance, args.reduced, progress_bar);
|
let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.param_tolerance, args.ll_tolerance, args.reduced, progress_bar);
|
||||||
|
|
||||||
// Display output
|
// Display output
|
||||||
match args.output {
|
match args.output {
|
||||||
@ -172,7 +176,7 @@ impl IntervalCensoredCoxData {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
|
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, param_tolerance: f64, ll_tolerance: Option<f64>, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
|
||||||
// ----------------------
|
// ----------------------
|
||||||
// Prepare for regression
|
// Prepare for regression
|
||||||
|
|
||||||
@ -256,6 +260,7 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
|
|||||||
progress_bar.println("Running E-M algorithm to fit interval-censored Cox model");
|
progress_bar.println("Running E-M algorithm to fit interval-censored Cox model");
|
||||||
|
|
||||||
let mut iteration: u32 = 0;
|
let mut iteration: u32 = 0;
|
||||||
|
let mut ll_model: f64 = 0.0;
|
||||||
loop {
|
loop {
|
||||||
// Pre-compute exp(β^T * Z_ik)
|
// Pre-compute exp(β^T * Z_ik)
|
||||||
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
|
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
|
||||||
@ -266,17 +271,31 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
|
|||||||
// Do M-step
|
// Do M-step
|
||||||
let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight);
|
let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight);
|
||||||
|
|
||||||
// Check for convergence
|
// Check for convergence (param_tolerance)
|
||||||
let (coef_change, converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, tolerance);
|
let (coef_change, mut converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, param_tolerance);
|
||||||
beta = new_beta;
|
beta = new_beta;
|
||||||
lambda = new_lambda;
|
lambda = new_lambda;
|
||||||
|
|
||||||
// Update progress bar
|
// Estimate progress bar according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations
|
||||||
// Estimate progress according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations
|
let progress1 = ((-coef_change.log10()).max(0.0) / -param_tolerance.log10() * u64::MAX as f64) as u64;
|
||||||
let progress1 = ((-coef_change.log10()).max(0.0) / -tolerance.log10() * u64::MAX as f64) as u64;
|
|
||||||
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
|
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
|
||||||
|
|
||||||
|
if let Some(ll_tolerance_amount) = ll_tolerance {
|
||||||
|
// Check for convergence (ll_tolerance)
|
||||||
|
let new_ll = log_likelihood_obs(&data, &beta, &lambda).sum();
|
||||||
|
let ll_change = new_ll - ll_model;
|
||||||
|
converged = converged && (ll_change < ll_tolerance_amount);
|
||||||
|
ll_model = new_ll;
|
||||||
|
|
||||||
|
// Update progress bar
|
||||||
|
let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance_amount.log10() * u64::MAX as f64) as u64;
|
||||||
|
progress_bar.set_position(progress_bar.position().max(progress1.min(progress3)).max(progress2));
|
||||||
|
progress_bar.set_message(format!("Iteration {} (Δparams = {:.6}, ΔLL = {:.4})", iteration + 1, coef_change, ll_change));
|
||||||
|
} else {
|
||||||
|
// Update progress bar
|
||||||
progress_bar.set_position(progress_bar.position().max(progress1).max(progress2));
|
progress_bar.set_position(progress_bar.position().max(progress1).max(progress2));
|
||||||
progress_bar.set_message(format!("Iter {} (delta = {:.6})", iteration + 1, coef_change));
|
progress_bar.set_message(format!("Iteration {} (Δparams = {:.6})", iteration + 1, coef_change));
|
||||||
|
}
|
||||||
|
|
||||||
if converged {
|
if converged {
|
||||||
progress_bar.finish();
|
progress_bar.finish();
|
||||||
@ -290,7 +309,11 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute log-likelihood
|
// Compute log-likelihood
|
||||||
let ll_model = log_likelihood_obs(&data, &beta, &lambda).sum();
|
if let Some(_) = ll_tolerance {
|
||||||
|
// Already computed above
|
||||||
|
} else {
|
||||||
|
ll_model = log_likelihood_obs(&data, &beta, &lambda).sum();
|
||||||
|
}
|
||||||
|
|
||||||
// Unstandardise betas
|
// Unstandardise betas
|
||||||
let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
|
let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
|
||||||
@ -313,15 +336,15 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
|
|||||||
// pll_toggle_zero = log-likelihoods for each observation at final beta
|
// pll_toggle_zero = log-likelihoods for each observation at final beta
|
||||||
// pll_toggle_one = log-likelihoods for each observation at toggled beta
|
// pll_toggle_one = log-likelihoods for each observation at toggled beta
|
||||||
|
|
||||||
let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, tolerance).sum();
|
let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, param_tolerance).sum();
|
||||||
|
|
||||||
let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, tolerance);
|
let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, param_tolerance);
|
||||||
progress_bar.inc(1);
|
progress_bar.inc(1);
|
||||||
|
|
||||||
let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
|
let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
|
||||||
let mut pll_beta = beta.clone();
|
let mut pll_beta = beta.clone();
|
||||||
pll_beta[j] += h;
|
pll_beta[j] += h;
|
||||||
profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, tolerance)
|
profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, param_tolerance)
|
||||||
})
|
})
|
||||||
.progress_with(progress_bar.clone())
|
.progress_with(progress_bar.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
Loading…
Reference in New Issue
Block a user