turnbull: Change convergence tolerance to be based on log-likelihood
This commit is contained in:
parent
0e39402d3d
commit
79c53895b0
@ -244,11 +244,8 @@ pub fn fit_interval_censored_cox(data_times: MatrixXx2<f64>, mut data_indep: DMa
|
||||
s = compute_s(&data, &lambda_new, &exp_z_beta);
|
||||
let ll_model_new = log_likelihood_obs(&s).sum();
|
||||
|
||||
let mut converged = true;
|
||||
let ll_change = ll_model_new - ll_model;
|
||||
if ll_change > ll_tolerance {
|
||||
converged = false;
|
||||
}
|
||||
let converged = ll_change <= ll_tolerance;
|
||||
|
||||
lambda = lambda_new;
|
||||
beta = beta_new;
|
||||
|
@ -43,9 +43,9 @@ pub struct TurnbullArgs {
|
||||
#[arg(long, default_value="1000")]
|
||||
max_iterations: u32,
|
||||
|
||||
/// Terminate algorithm when the absolute change in failure probability in each interval is less than this tolerance
|
||||
#[arg(long, default_value="0.0001")]
|
||||
fail_prob_tolerance: f64,
|
||||
/// Terminate algorithm when the absolute change in log-likelihood is less than this tolerance
|
||||
#[arg(long, default_value="0.01")]
|
||||
ll_tolerance: f64,
|
||||
|
||||
/// Method for computing standard error or survival probabilities
|
||||
#[arg(long, value_enum, default_value="oim")]
|
||||
@ -74,7 +74,7 @@ pub fn main(args: TurnbullArgs) {
|
||||
|
||||
// Fit regression
|
||||
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
|
||||
let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.fail_prob_tolerance, args.se_method, args.zero_tolerance);
|
||||
let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance);
|
||||
|
||||
// Display output
|
||||
match args.output {
|
||||
@ -169,7 +169,7 @@ impl TurnbullData {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
||||
pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
||||
// ----------------------
|
||||
// Prepare for regression
|
||||
|
||||
@ -207,7 +207,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
progress_bar.reset();
|
||||
progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
|
||||
|
||||
let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
|
||||
let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s);
|
||||
|
||||
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
|
||||
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
|
||||
@ -217,9 +217,6 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
survival_prob.push(acc);
|
||||
}
|
||||
|
||||
// Compute log-likelihood
|
||||
let ll = compute_log_likelihood(&data, &s);
|
||||
|
||||
// --------------------------------------------------
|
||||
// Compute standard errors for survival probabilities
|
||||
|
||||
@ -288,28 +285,33 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec<f64>) -> Vec<f64> {
|
||||
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut s: Vec<f64>) -> (Vec<f64>, f64) {
|
||||
// Get likelihood for each observation (denominator of μ_ij)
|
||||
let mut likelihood_obs = get_likelihood_obs(data, &s);
|
||||
let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum();
|
||||
|
||||
let mut iteration = 1;
|
||||
loop {
|
||||
// Get total failure probability for each observation (denominator of μ_ij)
|
||||
let sum_fail_prob = get_sum_fail_prob(data, &s);
|
||||
// Compute π_j to update s
|
||||
let pi = compute_pi(data, &s, likelihood_obs);
|
||||
|
||||
// Compute π_j
|
||||
let pi = compute_pi(data, &s, sum_fail_prob);
|
||||
let likelihood_obs_new = get_likelihood_obs(data, &pi);
|
||||
let ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
|
||||
|
||||
let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
|
||||
|
||||
let converged = largest_delta_s <= fail_prob_tolerance;
|
||||
let ll_change = ll_model_new - ll_model;
|
||||
let converged = ll_change <= ll_tolerance;
|
||||
|
||||
s = pi;
|
||||
likelihood_obs = likelihood_obs_new;
|
||||
ll_model = ll_model_new;
|
||||
|
||||
// Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations
|
||||
// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
|
||||
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
|
||||
let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64;
|
||||
let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance.log10() * u64::MAX as f64) as u64;
|
||||
|
||||
// Update progress bar
|
||||
progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
|
||||
progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s));
|
||||
progress_bar.set_message(format!("Iteration {} (LL = {:.4}, ΔLL = {:.4})", iteration + 1, ll_model, ll_change));
|
||||
|
||||
if converged {
|
||||
progress_bar.println(format!("Converged in {} iterations", iteration));
|
||||
@ -322,34 +324,34 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
|
||||
}
|
||||
}
|
||||
|
||||
return s;
|
||||
return (s, ll_model);
|
||||
}
|
||||
|
||||
fn get_sum_fail_prob(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
|
||||
fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
|
||||
return data.data_time_interval_indexes
|
||||
.par_iter()
|
||||
.map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum())
|
||||
.collect();
|
||||
}
|
||||
|
||||
fn compute_pi(data: &TurnbullData, s: &Vec<f64>, sum_fail_prob: Vec<f64>) -> Vec<f64> {
|
||||
fn compute_pi(data: &TurnbullData, s: &Vec<f64>, likelihood_obs: Vec<f64>) -> Vec<f64> {
|
||||
/*
|
||||
let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
|
||||
for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) {
|
||||
for ((idx_left, idx_right), likelihood_obs_i) in data.data_time_interval_indexes.iter().zip(likelihood_obs.iter()) {
|
||||
for j in *idx_left..(*idx_right + 1) {
|
||||
pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64;
|
||||
pi[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
let pi = data.data_time_interval_indexes.par_iter().zip(sum_fail_prob.par_iter())
|
||||
let pi = data.data_time_interval_indexes.par_iter().zip(likelihood_obs.par_iter())
|
||||
.fold_with(
|
||||
// Compute the contributions to pi[j] for each observation and sum them in parallel using fold_with
|
||||
vec![0.0; data.num_intervals()],
|
||||
|mut acc, ((idx_left, idx_right), sum_fail_prob_i)| {
|
||||
|mut acc, ((idx_left, idx_right), likelihood_obs_i)| {
|
||||
// Contributions to pi[j] for the i-th observation
|
||||
for j in *idx_left..(*idx_right + 1) {
|
||||
acc[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64;
|
||||
acc[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
|
||||
}
|
||||
acc
|
||||
}
|
||||
@ -366,17 +368,6 @@ fn compute_pi(data: &TurnbullData, s: &Vec<f64>, sum_fail_prob: Vec<f64>) -> Vec
|
||||
return pi;
|
||||
}
|
||||
|
||||
fn compute_log_likelihood(data: &TurnbullData, s: &Vec<f64>) -> f64 {
|
||||
let mut ll = 0.0;
|
||||
|
||||
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
|
||||
let likelihood_ob: f64 = s[*idx_left..(*idx_right + 1)].iter().sum();
|
||||
ll += likelihood_ob.ln();
|
||||
}
|
||||
|
||||
return ll;
|
||||
}
|
||||
|
||||
fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
|
||||
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
||||
|
||||
|
@ -26,7 +26,7 @@ fn test_turnbull_minitab() {
|
||||
|
||||
// Fit regression
|
||||
let progress_bar = ProgressBar::hidden();
|
||||
let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.0001, turnbull::SEMethod::OIM, 0.0001);
|
||||
let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.01, turnbull::SEMethod::OIM, 0.0001);
|
||||
|
||||
assert_eq!(result.failure_intervals[0], (20000.0, 30000.0));
|
||||
assert_eq!(result.failure_intervals[1], (30000.0, 40000.0));
|
||||
|
Loading…
Reference in New Issue
Block a user