turnbull: Initial implementation of EM-ICM algorithm

This commit is contained in:
RunasSudo 2023-10-28 22:48:59 +11:00
parent 79c53895b0
commit 250cfd8798
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 132 additions and 45 deletions

View File

@ -27,6 +27,7 @@ use prettytable::{Table, format, row};
use rayon::prelude::*; use rayon::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::pava::monotonic_regression_pava;
use crate::term::UnconditionalTermLike; use crate::term::UnconditionalTermLike;
#[derive(Args)] #[derive(Args)]
@ -205,7 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap()); progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
progress_bar.set_length(u64::MAX); progress_bar.set_length(u64::MAX);
progress_bar.reset(); progress_bar.reset();
progress_bar.println("Running iterative algorithm to fit Turnbull estimator"); progress_bar.println("Running EM-ICM algorithm to fit Turnbull estimator");
let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s); let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s);
@ -285,24 +286,144 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
return intervals; return intervals;
} }
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut s: Vec<f64>) -> (Vec<f64>, f64) { fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>) -> (Vec<f64>, f64) {
// Get likelihood for each observation (denominator of μ_ij) // Get likelihood for each observation
let mut likelihood_obs = get_likelihood_obs(data, &s); let mut likelihood_obs = get_likelihood_obs(data, &p);
let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum(); let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum();
let mut iteration = 1; let mut iteration = 1;
loop { loop {
// Compute π_j to update s // -------
let pi = compute_pi(data, &s, likelihood_obs); // EM step
let likelihood_obs_new = get_likelihood_obs(data, &pi); // Pre-compute S, the survival probability at the start of each interval
let ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); let mut s = Vec::with_capacity(data.num_intervals() + 1);
let mut survival = 1.0;
s.push(1.0);
for p_j in p.iter() {
survival -= p_j;
s.push(survival);
}
// Update p
let mut p_new = Vec::with_capacity(data.num_intervals());
for j in 0..data.num_intervals() {
let tmp: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right)
//.map(|(idx_left, idx_right)| 1.0 / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>())
.map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1]))
.sum();
p_new.push(p[j] * tmp / (data.num_obs() as f64));
}
let likelihood_obs_after_em = get_likelihood_obs(data, &p_new);
let ll_model_after_em: f64 = likelihood_obs_after_em.iter().map(|l| l.ln()).sum();
p = p_new;
// --------
// ICM step
// Compute Λ
// S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these
let mut lambda = Vec::with_capacity(data.num_intervals() - 1);
let mut survival: f64 = 1.0;
for j in 0..(data.num_intervals() - 1) {
survival -= p[j];
lambda.push((-survival.ln()).ln());
}
// Compute gradient
let mut gradient = DVector::zeros(data.num_intervals() - 1);
for j in 0..(data.num_intervals() - 1) {
let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>())
.sum();
let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>())
.sum();
gradient[j] = sum_right - sum_left;
}
// Compute diagonal of Hessian
let mut hessdiag = DVector::zeros(data.num_intervals() - 1);
for j in 0..(data.num_intervals() - 1) {
let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| {
let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum();
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-a - b
})
.sum();
let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| {
let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum();
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
a - b
})
.sum();
hessdiag[j] = sum_left + sum_right;
}
// Description in Anderson-Bergman (2017) is slightly misleading
// Since we are maximising the likelihood, the second derivatives will be negative
// And we will move in the direction of the gradient
// So there are a few more negative signs here than suggested
let weights = -hessdiag.clone() / 2.0;
let mut p_new;
let mut ll_model_new: f64;
// Take as large a step as possible while the log-likelihood increases
let mut step_size_exponent: i32 = 0;
loop {
let step_size = 0.5_f64.powi(step_size_exponent);
let lambda_target = -gradient.component_div(&hessdiag) * step_size + DVector::from_vec(lambda.clone());
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());
// Convert Λ to S to p
p_new = Vec::with_capacity(data.num_intervals());
let mut survival = 1.0;
for lambda_j in lambda_new.iter() {
let next_survival = (-lambda_j.exp()).exp();
p_new.push(survival - next_survival);
survival = next_survival;
}
p_new.push(survival);
let likelihood_obs_new = get_likelihood_obs(data, &p_new);
ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
if ll_model_new > ll_model_after_em {
break;
}
step_size_exponent += 1;
if step_size_exponent > 10 {
panic!("ICM fails to increase log-likelihood");
}
}
let ll_change = ll_model_new - ll_model; let ll_change = ll_model_new - ll_model;
let converged = ll_change <= ll_tolerance; let converged = ll_change <= ll_tolerance;
s = pi; p = p_new;
likelihood_obs = likelihood_obs_new; //likelihood_obs = likelihood_obs_new;
ll_model = ll_model_new; ll_model = ll_model_new;
// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
@ -324,7 +445,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
} }
} }
return (s, ll_model); return (p, ll_model);
} }
fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> { fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
@ -334,40 +455,6 @@ fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
.collect(); .collect();
} }
fn compute_pi(data: &TurnbullData, s: &Vec<f64>, likelihood_obs: Vec<f64>) -> Vec<f64> {
/*
let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
for ((idx_left, idx_right), likelihood_obs_i) in data.data_time_interval_indexes.iter().zip(likelihood_obs.iter()) {
for j in *idx_left..(*idx_right + 1) {
pi[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
}
}
*/
let pi = data.data_time_interval_indexes.par_iter().zip(likelihood_obs.par_iter())
.fold_with(
// Compute the contributions to pi[j] for each observation and sum them in parallel using fold_with
vec![0.0; data.num_intervals()],
|mut acc, ((idx_left, idx_right), likelihood_obs_i)| {
// Contributions to pi[j] for the i-th observation
for j in *idx_left..(*idx_right + 1) {
acc[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
}
acc
}
)
.reduce(
// Reduce all the sub-sums from fold_with into the total sum
|| vec![0.0; data.num_intervals()],
|mut acc, subsum| {
acc.iter_mut().zip(subsum.iter()).for_each(|(x, y)| *x += y);
acc
}
);
return pi;
}
fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> { fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);