From b23ff26eac4dd25a4ff10f68af5fe85eaecbfc05 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 29 Oct 2023 15:07:40 +1100 Subject: [PATCH] turnbull: Implement CIs by likelihood ratio test --- docs/turnbull.tex | 6 +- src/turnbull.rs | 259 +++++++++++++++++++++++++++++++++++++--------- tests/turnbull.rs | 15 +-- 3 files changed, 222 insertions(+), 58 deletions(-) diff --git a/docs/turnbull.tex b/docs/turnbull.tex index 1b7d9e7..d6fc495 100644 --- a/docs/turnbull.tex +++ b/docs/turnbull.tex @@ -49,7 +49,11 @@ % The sum of all $\nablasub{\hat{\symbf{F}}} \mathcal{L}_i$ yields the Hessian of the log-likelihood $\nablasub{\hat{\symbf{F}}} \mathcal{L}$. - The covariance matrix of $\hat{\symbf{F}}$ is given by the inverse of $-\nablasub{\hat{\symbf{F}}} \mathcal{L}$. The standard errors for each of $\hat{\symbf{F}}$ are the square roots of the diagonal elements of the covariance matrix, as required. Alternatively, when \textit{--se-method oim-drop-zeros} is passed, columns/rows of $\nablasub{\hat{\symbf{F}}} \mathcal{L}$ corresponding with intervals where $\hat{s}_i = 0$ are dropped before the matrix is inverted, which enables greater numerical stability but whose theoretical justification is not well explored [3]. + The covariance matrix of $\hat{\symbf{F}}$ is given by the inverse of $-\nablasub{\hat{\symbf{F}}} \mathcal{L}$. The standard errors for each of $\hat{\symbf{F}}$ are the square roots of the diagonal elements of the covariance matrix, as required. + + Alternatively, when \textit{--se-method oim-drop-zeros} is passed, columns/rows of $\nablasub{\hat{\symbf{F}}} \mathcal{L}$ corresponding with intervals where $\hat{s}_i = 0$ are dropped before the matrix is inverted, which enables greater numerical stability but whose theoretical justification is not well explored [3]. 
+ + In the further alternative, when \textit{--se-method likelihood-ratio} is passed, confidence intervals for $\hat{\symbf{F}}$ are computed by inverting a likelihood ratio test at each point, as described by Goodall, Dunn \& Babiker~[3]. {\vspace{0.5cm}\scshape\centering References\par} %{\pagebreak\scshape\centering References\par} diff --git a/src/turnbull.rs b/src/turnbull.rs index f536f3b..d220247 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -15,13 +15,14 @@ // along with this program. If not, see <https://www.gnu.org/licenses/>. const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64 +const CHI2_1DF_95: f64 = 3.8414588; use core::mem::MaybeUninit; use std::io; use clap::{Args, ValueEnum}; use csv::{Reader, StringRecord}; -use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; +use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2}; use prettytable::{Table, format, row}; use rayon::prelude::*; @@ -67,6 +68,7 @@ enum OutputFormat { pub enum SEMethod { OIM, OIMDropZeros, + LikelihoodRatio, } pub fn main(args: TurnbullArgs) { @@ -91,18 +93,37 @@ pub fn main(args: TurnbullArgs) { .build(); summary.set_format(format); - summary.set_titles(row!["Time", c->"Surv. Prob.", c->"Std Err.", H2c->"(95% CI)"]); - summary.add_row(row![r->"0.000", r->"1.00000", "", "", ""]); - for (i, prob) in result.survival_prob.iter().enumerate() { - summary.add_row(row![ - r->format!("{:.3}", result.failure_intervals[i].1), - r->format!("{:.5}", prob), - r->format!("{:.5}", result.survival_prob_se[i]), - r->format!("({:.5},", prob - Z_97_5 * result.survival_prob_se[i]), - format!("{:.5})", prob + Z_97_5 * result.survival_prob_se[i]), - ]); + if let Some(survival_prob_se) = &result.survival_prob_se { + // Standard errors available + summary.set_titles(row!["Time", c->"Surv. 
Prob.", c->"Std Err.", H2c->"(95% CI)"]); + summary.add_row(row![r->"0.000", r->"1.00000", "", "", ""]); + for (i, prob) in result.survival_prob.iter().enumerate() { + summary.add_row(row![ + r->format!("{:.3}", result.failure_intervals[i].1), + r->format!("{:.5}", prob), + r->format!("{:.5}", survival_prob_se[i]), + r->format!("({:.5},", prob - Z_97_5 * survival_prob_se[i]), + format!("{:.5})", prob + Z_97_5 * survival_prob_se[i]), + ]); + } + summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", "", ""]); + } else { + // No standard errors, just print CIs + let survival_prob_ci = result.survival_prob_ci.as_ref().unwrap(); + + summary.set_titles(row!["Time", c->"Surv. Prob.", H2c->"(95% CI)"]); + summary.add_row(row![r->"0.000", r->"1.00000", "", ""]); + for (i, prob) in result.survival_prob.iter().enumerate() { + summary.add_row(row![ + r->format!("{:.3}", result.failure_intervals[i].1), + r->format!("{:.5}", prob), + r->format!("({:.5},", survival_prob_ci[i].0), + format!("{:.5})", survival_prob_ci[i].1), + ]); + } + summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", ""]); } - summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", "", ""]); + summary.printstd(); } OutputFormat::Json => { @@ -170,6 +191,12 @@ impl TurnbullData { } } +/// Constrains the survival probability at a particular time s[time_index] == survival_prob +struct Constraint { + time_index: usize, + survival_prob: f64, +} + pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { // ---------------------- // Prepare for regression @@ -195,7 +222,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as 
Vec let p = vec![1.0 / intervals.len() as f64; intervals.len()]; - let mut data = TurnbullData { + let data = TurnbullData { data_time_interval_indexes: data_time_interval_indexes, intervals: intervals, }; @@ -208,7 +235,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Running EM-ICM algorithm to fit Turnbull estimator"); - let (p, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, p); + let (p, ll) = fit_turnbull_estimator(&data, progress_bar.clone(), max_iterations, ll_tolerance, p, None); // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0) let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1); @@ -221,42 +248,31 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i // -------------------------------------------------- // Compute standard errors for survival probabilities - let hessian = compute_hessian(&data, &p); - - let mut survival_prob_se: DVector<f64>; + let mut survival_prob_se = None; + let mut survival_prob_ci = None; match se_method { SEMethod::OIM => { - // Compute covariance matrix as inverse of negative Hessian - let vcov = -hessian.try_inverse().expect("Matrix not invertible"); - survival_prob_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); }); + survival_prob_se = Some(survival_prob_oim_se(&data, &p, zero_tolerance, false)); } SEMethod::OIMDropZeros => { - // Drop rows/columns of Hessian corresponding to intervals with zero failure probability - let nonzero_intervals: Vec<usize> = (0..(data.num_intervals() - 1)).filter(|i| p[*i] > zero_tolerance).collect(); + survival_prob_se = Some(survival_prob_oim_se(&data, &p, zero_tolerance, true)); + } + SEMethod::LikelihoodRatio => { + let s = p_to_s(&p); + let oim_se = survival_prob_oim_se(&data, &p, zero_tolerance, true); - let mut hessian_nonzero: DMatrix<f64> = 
DMatrix::zeros(nonzero_intervals.len(), nonzero_intervals.len()); - for (nonzero_index1, orig_index1) in nonzero_intervals.iter().enumerate() { - hessian_nonzero[(nonzero_index1, nonzero_index1)] = hessian[(*orig_index1, *orig_index1)]; - for (nonzero_index2, orig_index2) in nonzero_intervals.iter().enumerate().take(nonzero_index1) { - hessian_nonzero[(nonzero_index1, nonzero_index2)] = hessian[(*orig_index1, *orig_index2)]; - hessian_nonzero[(nonzero_index2, nonzero_index1)] = hessian[(*orig_index2, *orig_index1)]; - } - } + progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} CI {pos}/{len}").unwrap()); + progress_bar.set_length(data.num_intervals() as u64 - 1); + progress_bar.reset(); + progress_bar.println("Computing confidence intervals by likelihood ratio test"); - let vcov = -hessian_nonzero.try_inverse().expect("Matrix not invertible"); - let survival_prob_se_nonzero = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); }); + let confidence_intervals = (1..data.num_intervals()).into_par_iter() + .map(|j| survival_prob_likelihood_ratio_ci(&data, ProgressBar::hidden(), max_iterations, ll_tolerance, &p, ll, &s, &oim_se, j)) + .progress_with(progress_bar.clone()) + .collect(); - survival_prob_se = DVector::zeros(data.num_intervals() - 1); - let mut nonzero_index = 0; - for orig_index in 0..(data.num_intervals() - 1) { - if nonzero_intervals.contains(&orig_index) { - survival_prob_se[orig_index] = survival_prob_se_nonzero[nonzero_index]; - nonzero_index += 1; - } else { - survival_prob_se[orig_index] = survival_prob_se[orig_index - 1]; - } - } + survival_prob_ci = Some(confidence_intervals); } } @@ -264,7 +280,8 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i failure_intervals: data.intervals, failure_prob: p, survival_prob: survival_prob, - survival_prob_se: survival_prob_se.data.as_vec().clone(), + survival_prob_se: survival_prob_se, + survival_prob_ci: survival_prob_ci, ll_model: ll, }; } @@ -286,7 
+303,7 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> { return intervals; } -fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>) -> (Vec<f64>, f64) { +fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>, constraint: Option<Constraint>) -> (Vec<f64>, f64) { // Pre-compute S, the survival probability at the start of each interval let mut s = p_to_s(&p); @@ -299,7 +316,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma // ------- // EM step - let p_after_em = do_em_step(data, &p, &s); + let p_after_em = do_em_step(data, &p, &s, &constraint); let s_after_em = p_to_s(&p_after_em); let likelihood_obs_after_em = get_likelihood_obs(data, &s_after_em); @@ -311,7 +328,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma // -------- // ICM step - let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_tolerance, ll_model_after_em); + let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_tolerance, &constraint, ll_model_after_em); let ll_change = ll_model_new - ll_model; let converged = ll_change <= ll_tolerance; @@ -369,7 +386,7 @@ fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> { .collect(); // TODO: Return iterator directly } -fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> { +fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, constraint: &Option<Constraint>) -> Vec<f64> { // Compute contributions to m let mut m_contrib = vec![0.0; data.num_intervals()]; for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { @@ -395,12 +412,20 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> { // Update p // p := p * m - let p_new = p.par_iter().zip(m.into_par_iter()).map(|(p_j, m_j)| p_j * m_j).collect(); + let mut p_new: Vec<f64> = p.par_iter().zip(m.into_par_iter()).map(|(p_j, m_j)| p_j * m_j).collect(); + + // Constrain 
if required + if let Some(c) = &constraint { + let cur_fail_prob: f64 = p_new[0..c.time_index].iter().copied().sum(); + // Not sure why borrow checker thinks there is an unused borrow here... + let _ = &mut p_new[0..c.time_index].iter_mut().for_each(|x| *x *= (1.0 - c.survival_prob) / cur_fail_prob); // Desired failure probability over current failure probability + let _ = &mut p_new[c.time_index..].iter_mut().for_each(|x| *x *= c.survival_prob / (1.0 - cur_fail_prob)); + } return p_new; } -fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f64, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) { +fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f64, constraint: &Option<Constraint>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) { // Compute Λ, the cumulative hazard // Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted // The entry at lambda[j] corresponds to the survival immediately before time point j + 1 @@ -474,6 +499,17 @@ fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f6 ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); if ll_model_new > ll_model { + // Constrain if required + if let Some(c) = constraint { + let cur_survival_prob = s_new[c.time_index]; + let _ = &mut p_new[0..c.time_index].iter_mut().for_each(|x| *x *= (1.0 - c.survival_prob) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability + let _ = &mut p_new[c.time_index..].iter_mut().for_each(|x| *x *= c.survival_prob / cur_survival_prob); + + s_new = p_to_s(&p_new); + let likelihood_obs_new = get_likelihood_obs(data, &s_new); + ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); + } + return (p_new, s_new, ll_model_new); } @@ -491,6 +527,44 @@ fn do_icm_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>, ll_tolerance: f6 } } +fn survival_prob_oim_se(data: &TurnbullData, p: &Vec<f64>, zero_tolerance: f64, drop_zeros: bool) -> Vec<f64> { + let hessian = compute_hessian(&data, &p); + + if drop_zeros { + // 
Drop rows/columns of Hessian corresponding to intervals with zero failure probability + let nonzero_intervals: Vec<usize> = (0..(data.num_intervals() - 1)).filter(|i| p[*i] > zero_tolerance).collect(); + + let mut hessian_nonzero: DMatrix<f64> = DMatrix::zeros(nonzero_intervals.len(), nonzero_intervals.len()); + for (nonzero_index1, orig_index1) in nonzero_intervals.iter().enumerate() { + hessian_nonzero[(nonzero_index1, nonzero_index1)] = hessian[(*orig_index1, *orig_index1)]; + for (nonzero_index2, orig_index2) in nonzero_intervals.iter().enumerate().take(nonzero_index1) { + hessian_nonzero[(nonzero_index1, nonzero_index2)] = hessian[(*orig_index1, *orig_index2)]; + hessian_nonzero[(nonzero_index2, nonzero_index1)] = hessian[(*orig_index2, *orig_index1)]; + } + } + + let vcov = -hessian_nonzero.try_inverse().expect("Matrix not invertible"); + let survival_prob_se_nonzero = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); }); + + let mut se = vec![0.0; data.num_intervals() - 1]; + let mut nonzero_index = 0; + for orig_index in 0..(data.num_intervals() - 1) { + if nonzero_intervals.contains(&orig_index) { + se[orig_index] = survival_prob_se_nonzero[nonzero_index]; + nonzero_index += 1; + } else { + se[orig_index] = se[orig_index - 1]; + } + } + return se; + } else { + // Compute covariance matrix as inverse of negative Hessian + let vcov = -hessian.try_inverse().expect("Matrix not invertible"); + let se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); }); + return se.data.as_vec().clone(); + } +} + fn compute_hessian(data: &TurnbullData, p: &Vec<f64>) -> DMatrix<f64> { let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); @@ -539,11 +613,96 @@ fn compute_hessian(data: &TurnbullData, p: &Vec<f64>) -> DMatrix<f64> { return hessian; } +fn survival_prob_likelihood_ratio_ci(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, p: &Vec<f64>, ll_model: f64, s: &Vec<f64>, oim_se: &Vec<f64>, time_index: usize) -> (f64, f64) { + // Compute lower 
confidence limit + let mut ci_bound_lower = 0.0; + let mut ci_bound_upper = s[time_index]; + let mut ci_estimate = s[time_index] - Z_97_5 * oim_se[time_index - 1]; + if ci_estimate < 0.0 { + ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + } + + let mut iteration = 1; + loop { + // Get starting guess, constrained at time_index + let mut p_test = p.clone(); + let cur_survival_prob = s[time_index]; + let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability + let _ = &mut p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob); + + let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); + let lr_statistic = 2.0 * (ll_model - ll_test); + + if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { + // Converged! + break; + } else if lr_statistic > CHI2_1DF_95 { + // CI is too wide + ci_bound_lower = ci_estimate; + } else { + // CI is too narrow + ci_bound_upper = ci_estimate; + } + + ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + + iteration += 1; + if iteration > max_iterations { + panic!("Exceeded --max-iterations"); + } + } + + let ci_lower = ci_estimate; + + // Compute upper confidence limit + ci_bound_lower = s[time_index]; + ci_bound_upper = 1.0; + ci_estimate = s[time_index] + Z_97_5 * oim_se[time_index - 1]; + if ci_estimate > 1.0 { + ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + } + + let mut iteration = 1; + loop { + // Get starting guess, constrained at time_index + let mut p_test = p.clone(); + let cur_survival_prob = s[time_index]; + let _ = &mut p_test[0..time_index].iter_mut().for_each(|x| *x *= (1.0 - ci_estimate) / (1.0 - cur_survival_prob)); // Desired failure probability over current failure probability + let _ = &mut 
p_test[time_index..].iter_mut().for_each(|x| *x *= ci_estimate / cur_survival_prob); + + let (_p, ll_test) = fit_turnbull_estimator(data, progress_bar.clone(), max_iterations, ll_tolerance, p_test, Some(Constraint { time_index: time_index, survival_prob: ci_estimate })); + let lr_statistic = 2.0 * (ll_model - ll_test); + + if (lr_statistic - CHI2_1DF_95).abs() < ll_tolerance { + // Converged! + break; + } else if lr_statistic > CHI2_1DF_95 { + // CI is too wide + ci_bound_upper = ci_estimate; + } else { + // CI is too narrow + ci_bound_lower = ci_estimate; + } + + ci_estimate = (ci_bound_lower + ci_bound_upper) / 2.0; + + iteration += 1; + if iteration > max_iterations { + panic!("Exceeded --max-iterations"); + } + } + + let ci_upper = ci_estimate; + + return (ci_lower, ci_upper); +} + #[derive(Serialize, Deserialize)] pub struct TurnbullResult { pub failure_intervals: Vec<(f64, f64)>, pub failure_prob: Vec<f64>, pub survival_prob: Vec<f64>, - pub survival_prob_se: Vec<f64>, + pub survival_prob_se: Option<Vec<f64>>, + pub survival_prob_ci: Option<Vec<(f64, f64)>>, pub ll_model: f64, } diff --git a/tests/turnbull.rs b/tests/turnbull.rs index 41499a3..8a92cab 100644 --- a/tests/turnbull.rs +++ b/tests/turnbull.rs @@ -54,13 +54,14 @@ fn test_turnbull_minitab() { assert!(abs_diff(result.survival_prob[5], 0.431840) < 0.000001); assert!(abs_diff(result.survival_prob[6], 0.200191) < 0.000001); - assert!(abs_diff(result.survival_prob_se[0], 0.0016488) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[1], 0.0035430) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[2], 0.0064517) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[3], 0.0109856) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[4], 0.0143949) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[5], 0.0152936) < 0.0000001); - assert!(abs_diff(result.survival_prob_se[6], 0.0123546) < 0.0000001); + let survival_prob_se = result.survival_prob_se.as_ref().unwrap(); + assert!(abs_diff(survival_prob_se[0], 0.0016488) < 
0.0000001); + assert!(abs_diff(survival_prob_se[1], 0.0035430) < 0.0000001); + assert!(abs_diff(survival_prob_se[2], 0.0064517) < 0.0000001); + assert!(abs_diff(survival_prob_se[3], 0.0109856) < 0.0000001); + assert!(abs_diff(survival_prob_se[4], 0.0143949) < 0.0000001); + assert!(abs_diff(survival_prob_se[5], 0.0152936) < 0.0000001); + assert!(abs_diff(survival_prob_se[6], 0.0123546) < 0.0000001); } fn abs_diff(a: f64, b: f64) -> f64 {