turnbull: Improve efficiency of ICM step as suggested by Anderson-Bergman
This commit is contained in:
parent
85e3ee0dcd
commit
e726fad99b
@ -402,48 +402,37 @@ fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
|
|||||||
|
|
||||||
fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
|
fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
|
||||||
// Compute Λ, the cumulative hazard
|
// Compute Λ, the cumulative hazard
|
||||||
|
// Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
|
||||||
|
// The entry at lambda[j] corresponds to the survival immediately before time point j + 1
|
||||||
let lambda = s_to_lambda(&s);
|
let lambda = s_to_lambda(&s);
|
||||||
|
|
||||||
// Compute gradient
|
// Compute gradient and diagonal of Hessian
|
||||||
let mut gradient = DVector::zeros(data.num_intervals() - 1);
|
let mut gradient = vec![0.0; data.num_intervals() - 1];
|
||||||
for j in 0..(data.num_intervals() - 1) {
|
let mut hessdiag = vec![0.0; data.num_intervals() - 1];
|
||||||
let sum_right: f64 = data.data_time_interval_indexes.iter()
|
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
|
||||||
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
|
let denom = s[*idx_left] - s[*idx_right + 1];
|
||||||
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
let sum_left: f64 = data.data_time_interval_indexes.iter()
|
// Add to gradient[j] when j + 1 == idx_right + 1
|
||||||
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
|
// Add to hessdiag[j] when j + 1 == idx_right + 1
|
||||||
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
|
if *idx_right < gradient.len() {
|
||||||
.sum();
|
let j = *idx_right;
|
||||||
|
gradient[j] += (-lambda[j].exp() + lambda[j]).exp() / denom;
|
||||||
|
|
||||||
|
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
|
||||||
|
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
|
||||||
|
hessdiag[j] += a - b;
|
||||||
|
}
|
||||||
|
|
||||||
gradient[j] = sum_right - sum_left;
|
// Subtract from gradient[j] when j + 1 == idx_left
|
||||||
}
|
// Add to hessdiag[j] when j + 1 == idx_left
|
||||||
|
if *idx_left > 0 {
|
||||||
// Compute diagonal of Hessian
|
let j = *idx_left - 1;
|
||||||
let mut hessdiag = DVector::zeros(data.num_intervals() - 1);
|
gradient[j] -= (-lambda[j].exp() + lambda[j]).exp() / denom;
|
||||||
for j in 0..(data.num_intervals() - 1) {
|
|
||||||
let sum_left: f64 = data.data_time_interval_indexes.iter()
|
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
|
||||||
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
|
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
|
||||||
.map(|(idx_left, idx_right)| {
|
hessdiag[j] += -a - b;
|
||||||
let denom = s[*idx_left] - s[*idx_right + 1];
|
}
|
||||||
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
|
|
||||||
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
|
|
||||||
-a - b
|
|
||||||
})
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
let sum_right: f64 = data.data_time_interval_indexes.iter()
|
|
||||||
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
|
|
||||||
.map(|(idx_left, idx_right)| {
|
|
||||||
let denom = s[*idx_left] - s[*idx_right + 1];
|
|
||||||
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
|
|
||||||
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
|
|
||||||
a - b
|
|
||||||
})
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
hessdiag[j] = sum_left + sum_right;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Description in Anderson-Bergman (2017) is slightly misleading
|
// Description in Anderson-Bergman (2017) is slightly misleading
|
||||||
@ -451,7 +440,8 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
|
|||||||
// And we will move in the direction of the gradient
|
// And we will move in the direction of the gradient
|
||||||
// So there are a few more negative signs here than suggested
|
// So there are a few more negative signs here than suggested
|
||||||
|
|
||||||
let weights = -hessdiag.clone() / 2.0;
|
let weights = -DVector::from_vec(hessdiag.clone()) / 2.0;
|
||||||
|
let gradient_over_hessdiag = DVector::from_vec(gradient.par_iter().zip(hessdiag.par_iter()).map(|(g, h)| g / h).collect());
|
||||||
|
|
||||||
let mut s_new;
|
let mut s_new;
|
||||||
let mut p_new;
|
let mut p_new;
|
||||||
@ -461,7 +451,7 @@ fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64)
|
|||||||
let mut step_size_exponent: i32 = 0;
|
let mut step_size_exponent: i32 = 0;
|
||||||
loop {
|
loop {
|
||||||
let step_size = 0.5_f64.powi(step_size_exponent);
|
let step_size = 0.5_f64.powi(step_size_exponent);
|
||||||
let lambda_target = -gradient.component_div(&hessdiag) * step_size + DVector::from_vec(lambda.clone());
|
let lambda_target = -gradient_over_hessdiag.clone() * step_size + DVector::from_vec(lambda.clone());
|
||||||
|
|
||||||
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());
|
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user