turnbull: Improve efficiency of ICM step as suggested by Anderson-Bergman

This commit is contained in:
RunasSudo 2023-10-28 23:47:00 +11:00
parent 85e3ee0dcd
commit e726fad99b
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -402,48 +402,37 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
// Compute Λ, the cumulative hazard
// Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
// The entry at lambda[j] corresponds to the survival immediately before time point j + 1
let lambda = s_to_lambda(&s);
// Compute gradient
let mut gradient = DVector::zeros(data.num_intervals() - 1);
for j in 0..(data.num_intervals() - 1) {
let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
.sum();
// Compute gradient and diagonal of Hessian
let mut gradient = vec![0.0; data.num_intervals() - 1];
let mut hessdiag = vec![0.0; data.num_intervals() - 1];
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
let denom = s[*idx_left] - s[*idx_right + 1];
let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
.sum();
// Add to gradient[j] when j + 1 == idx_right + 1
// Add to hessdiag[j] when j + 1 == idx_right + 1
if *idx_right < gradient.len() {
let j = *idx_right;
gradient[j] += (-lambda[j].exp() + lambda[j]).exp() / denom;
gradient[j] = sum_right - sum_left;
}
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
hessdiag[j] += a - b;
}
// Compute diagonal of Hessian
let mut hessdiag = DVector::zeros(data.num_intervals() - 1);
for j in 0..(data.num_intervals() - 1) {
let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| {
let denom = s[*idx_left] - s[*idx_right + 1];
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-a - b
})
.sum();
// Subtract from gradient[j] when j + 1 == idx_left
// Add to hessdiag[j] when j + 1 == idx_left
if *idx_left > 0 {
let j = *idx_left - 1;
gradient[j] -= (-lambda[j].exp() + lambda[j]).exp() / denom;
let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| {
let denom = s[*idx_left] - s[*idx_right + 1];
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
a - b
})
.sum();
hessdiag[j] = sum_left + sum_right;
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
hessdiag[j] += -a - b;
}
}
// Description in Anderson-Bergman (2017) is slightly misleading
@ -451,7 +440,8 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
// And we will move in the direction of the gradient
// So there are a few more negative signs here than suggested
let weights = -hessdiag.clone() / 2.0;
let weights = -DVector::from_vec(hessdiag.clone()) / 2.0;
let gradient_over_hessdiag = DVector::from_vec(gradient.par_iter().zip(hessdiag.par_iter()).map(|(g, h)| g / h).collect());
let mut s_new;
let mut p_new;
@ -461,7 +451,7 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
let mut step_size_exponent: i32 = 0;
loop {
let step_size = 0.5_f64.powi(step_size_exponent);
let lambda_target = -gradient.component_div(&hessdiag) * step_size + DVector::from_vec(lambda.clone());
let lambda_target = -gradient_over_hessdiag.clone() * step_size + DVector::from_vec(lambda.clone());
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());