From e726fad99b16db77a9ba0f5465511b38143b9449 Mon Sep 17 00:00:00 2001
From: RunasSudo
Date: Sat, 28 Oct 2023 23:47:00 +1100
Subject: [PATCH] turnbull: Improve efficiency of ICM step as suggested by
 Anderson-Bergman

---
 src/turnbull.rs | 70 +++++++++++++++++++++----------------------
 1 file changed, 30 insertions(+), 40 deletions(-)

diff --git a/src/turnbull.rs b/src/turnbull.rs
index 7e31a02..fd67669 100644
--- a/src/turnbull.rs
+++ b/src/turnbull.rs
@@ -402,48 +402,37 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
 
 fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
 	// Compute Λ, the cumulative hazard
+	// Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
+	// The entry at lambda[j] corresponds to the survival immediately before time point j + 1
 	let lambda = s_to_lambda(&s);
 	
-	// Compute gradient
-	let mut gradient = DVector::zeros(data.num_intervals() - 1);
-	for j in 0..(data.num_intervals() - 1) {
-		let sum_right: f64 = data.data_time_interval_indexes.iter()
-			.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
-			.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
-			.sum();
+	// Compute gradient and diagonal of Hessian
+	let mut gradient = vec![0.0; data.num_intervals() - 1];
+	let mut hessdiag = vec![0.0; data.num_intervals() - 1];
+	for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
+		let denom = s[*idx_left] - s[*idx_right + 1];
 		
-		let sum_left: f64 = data.data_time_interval_indexes.iter()
-			.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
-			.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
-			.sum();
+		// Add to gradient[j] when j + 1 == idx_right + 1
+		// Add to hessdiag[j] when j + 1 == idx_right + 1
+		if *idx_right < gradient.len() {
+			let j = *idx_right;
+			gradient[j] += (-lambda[j].exp() + lambda[j]).exp() / denom;
+			
+			let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
+			let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
+			hessdiag[j] += a - b;
+		}
 		
-		gradient[j] = sum_right - sum_left;
-	}
-	
-	// Compute diagonal of Hessian
-	let mut hessdiag = DVector::zeros(data.num_intervals() - 1);
-	for j in 0..(data.num_intervals() - 1) {
-		let sum_left: f64 = data.data_time_interval_indexes.iter()
-			.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
-			.map(|(idx_left, idx_right)| {
-				let denom = s[*idx_left] - s[*idx_right + 1];
-				let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
-				let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-				-a - b
-			})
-			.sum();
-		
-		let sum_right: f64 = data.data_time_interval_indexes.iter()
-			.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
-			.map(|(idx_left, idx_right)| {
-				let denom = s[*idx_left] - s[*idx_right + 1];
-				let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
-				let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-				a - b
-			})
-			.sum();
-		
-		hessdiag[j] = sum_left + sum_right;
+		// Subtract from gradient[j] when j + 1 == idx_left
+		// Add to hessdiag[j] when j + 1 == idx_left
+		if *idx_left > 0 {
+			let j = *idx_left - 1;
+			gradient[j] -= (-lambda[j].exp() + lambda[j]).exp() / denom;
+			
+			let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
+			let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
+			hessdiag[j] += -a - b;
+		}
 	}
 	
 	// Description in Anderson-Bergman (2017) is slightly misleading
@@ -451,7 +440,8 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
 	// And we will move in the direction of the gradient
 	// So there are a few more negative signs here than suggested
 	
-	let weights = -hessdiag.clone() / 2.0;
+	let weights = -DVector::from_vec(hessdiag.clone()) / 2.0;
+	let gradient_over_hessdiag = DVector::from_vec(gradient.par_iter().zip(hessdiag.par_iter()).map(|(g, h)| g / h).collect());
 	
 	let mut s_new;
 	let mut p_new;
@@ -461,7 +451,7 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
 	let mut step_size_exponent: i32 = 0;
 	loop {
 		let step_size = 0.5_f64.powi(step_size_exponent);
-		let lambda_target = -gradient.component_div(&hessdiag) * step_size + DVector::from_vec(lambda.clone());
+		let lambda_target = -gradient_over_hessdiag.clone() * step_size + DVector::from_vec(lambda.clone());
 		let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());