turnbull: Improve efficiency of ICM step as suggested by Anderson-Bergman

This commit is contained in:
RunasSudo 2023-10-28 23:47:00 +11:00
parent 85e3ee0dcd
commit e726fad99b
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed file with 30 additions and 40 deletions

View File

@@ -402,48 +402,37 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
// Compute Λ, the cumulative hazard
// Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
// The entry at lambda[j] corresponds to the survival immediately before time point j + 1
let lambda = s_to_lambda(&s);
// Compute gradient // Compute gradient and diagonal of Hessian
let mut gradient = DVector::zeros(data.num_intervals() - 1); let mut gradient = vec![0.0; data.num_intervals() - 1];
for j in 0..(data.num_intervals() - 1) { let mut hessdiag = vec![0.0; data.num_intervals() - 1];
let sum_right: f64 = data.data_time_interval_indexes.iter() for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1) let denom = s[*idx_left] - s[*idx_right + 1];
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
.sum();
let sum_left: f64 = data.data_time_interval_indexes.iter() // Add to gradient[j] when j + 1 == idx_right + 1
.filter(|(idx_left, idx_right)| j + 1 == *idx_left) // Add to hessdiag[j] when j + 1 == idx_right + 1
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1])) if *idx_right < gradient.len() {
.sum(); let j = *idx_right;
gradient[j] += (-lambda[j].exp() + lambda[j]).exp() / denom;
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
hessdiag[j] += a - b;
}
gradient[j] = sum_right - sum_left; // Subtract from gradient[j] when j + 1 == idx_left
} // Add to hessdiag[j] when j + 1 == idx_left
if *idx_left > 0 {
// Compute diagonal of Hessian let j = *idx_left - 1;
let mut hessdiag = DVector::zeros(data.num_intervals() - 1); gradient[j] -= (-lambda[j].exp() + lambda[j]).exp() / denom;
for j in 0..(data.num_intervals() - 1) {
let sum_left: f64 = data.data_time_interval_indexes.iter() let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
.filter(|(idx_left, idx_right)| j + 1 == *idx_left) let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
.map(|(idx_left, idx_right)| { hessdiag[j] += -a - b;
let denom = s[*idx_left] - s[*idx_right + 1]; }
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-a - b
})
.sum();
let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| {
let denom = s[*idx_left] - s[*idx_right + 1];
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
a - b
})
.sum();
hessdiag[j] = sum_left + sum_right;
}
// Description in Anderson-Bergman (2017) is slightly misleading
@@ -451,7 +440,8 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
// And we will move in the direction of the gradient
// So there are a few more negative signs here than suggested
let weights = -hessdiag.clone() / 2.0; let weights = -DVector::from_vec(hessdiag.clone()) / 2.0;
let gradient_over_hessdiag = DVector::from_vec(gradient.par_iter().zip(hessdiag.par_iter()).map(|(g, h)| g / h).collect());
let mut s_new;
let mut p_new;
@@ -461,7 +451,7 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
let mut step_size_exponent: i32 = 0;
loop {
let step_size = 0.5_f64.powi(step_size_exponent);
let lambda_target = -gradient.component_div(&hessdiag) * step_size + DVector::from_vec(lambda.clone()); let lambda_target = -gradient_over_hessdiag.clone() * step_size + DVector::from_vec(lambda.clone());
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());