From 81b0b3f9b54066502531267996226fa73fdf3adb Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 28 Oct 2023 23:08:03 +1100 Subject: [PATCH] turnbull: Pre-compute survival probabilities --- src/turnbull.rs | 69 +++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/src/turnbull.rs b/src/turnbull.rs index d685ce9..d5b93f1 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -287,8 +287,11 @@ fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { } fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec) -> (Vec, f64) { + // Pre-compute S, the survival probability at the start of each interval + let mut s = p_to_s(&p); + // Get likelihood for each observation - let mut likelihood_obs = get_likelihood_obs(data, &p); + let mut likelihood_obs = get_likelihood_obs(data, &s); let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum(); let mut iteration = 1; @@ -296,55 +299,41 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma // ------- // EM step - // Pre-compute S, the survival probability at the start of each interval - let mut s = Vec::with_capacity(data.num_intervals() + 1); - let mut survival = 1.0; - s.push(1.0); - for p_j in p.iter() { - survival -= p_j; - s.push(survival); - } - // Update p let mut p_new = Vec::with_capacity(data.num_intervals()); for j in 0..data.num_intervals() { let tmp: f64 = data.data_time_interval_indexes.iter() .filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right) - //.map(|(idx_left, idx_right)| 1.0 / p[*idx_left..(*idx_right + 1)].iter().sum::()) .map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1])) .sum(); p_new.push(p[j] * tmp / (data.num_obs() as f64)); } - let likelihood_obs_after_em = get_likelihood_obs(data, &p_new); + let mut s_new = p_to_s(&p_new); + let likelihood_obs_after_em = get_likelihood_obs(data, &s_new); let ll_model_after_em: f64 = likelihood_obs_after_em.iter().map(|l| l.ln()).sum(); p = p_new; + s = s_new; // -------- // ICM step - // Compute Λ - // S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these - let mut lambda = Vec::with_capacity(data.num_intervals() - 1); - let mut survival: f64 = 1.0; - for j in 0..(data.num_intervals() - 1) { - survival -= p[j]; - lambda.push((-survival.ln()).ln()); - } + // Compute Λ, the cumulative hazard + let lambda = s_to_lambda(&s); // Compute gradient let mut gradient = DVector::zeros(data.num_intervals() - 1); for j in 0..(data.num_intervals() - 1) { let sum_right: f64 = data.data_time_interval_indexes.iter() .filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1) - .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::()) + .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1])) .sum(); let sum_left: f64 = data.data_time_interval_indexes.iter() .filter(|(idx_left, idx_right)| j + 1 == *idx_left) - .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::()) + .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1])) .sum(); gradient[j] = sum_right - sum_left; @@ -356,7 +345,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma let sum_left: f64 = data.data_time_interval_indexes.iter() .filter(|(idx_left, idx_right)| j + 1 == *idx_left) .map(|(idx_left, idx_right)| { - let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum(); + let denom = s[*idx_left] - s[*idx_right + 1]; let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom; let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2); -a - b @@ -366,7 +355,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma let sum_right: f64 = data.data_time_interval_indexes.iter() .filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1) .map(|(idx_left, idx_right)| { - let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum(); + let denom = s[*idx_left] - s[*idx_right + 1]; let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom; let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2); a - b @@ -395,17 +384,21 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma let lambda_new = monotonic_regression_pava(lambda_target, weights.clone()); // Convert Λ to S to p + s_new = Vec::with_capacity(data.num_intervals() + 1); p_new = Vec::with_capacity(data.num_intervals()); let mut survival = 1.0; + s_new.push(1.0); for lambda_j in lambda_new.iter() { let next_survival = (-lambda_j.exp()).exp(); + s_new.push(next_survival); p_new.push(survival - next_survival); survival = next_survival; } + s_new.push(0.0); p_new.push(survival); - let likelihood_obs_new = get_likelihood_obs(data, &p_new); + let likelihood_obs_new = get_likelihood_obs(data, &s_new); ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); if ll_model_new > ll_model_after_em { @@ -423,7 +416,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma let converged = ll_change <= ll_tolerance; p = p_new; - //likelihood_obs = likelihood_obs_new; + s = s_new; ll_model = ll_model_new; // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations @@ -448,11 +441,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma return (p, ll_model); } +fn p_to_s(p: &Vec) -> Vec { + let mut s = Vec::with_capacity(p.len() + 1); // Survival probabilities + let mut survival = 1.0; + s.push(1.0); + for p_j in p.iter() { + survival -= p_j; + s.push(survival); + } + return s; +} + +fn s_to_lambda(s: &Vec) -> Vec { + // S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these + let mut lambda = Vec::with_capacity(s.len() - 2); // Cumulative hazard + for s_j in &s[1..(s.len() - 1)] { + lambda.push((-s_j.ln()).ln()); + } + return lambda; +} + fn get_likelihood_obs(data: &TurnbullData, s: &Vec) -> Vec { return data.data_time_interval_indexes .par_iter() - .map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum()) - .collect(); + .map(|(idx_left, idx_right)| s[*idx_left] - s[*idx_right + 1]) + .collect(); // TODO: Return iterator directly } fn compute_hessian(data: &TurnbullData, s: &Vec) -> DMatrix {