Merge branch 'turnbull-emicm'
Change to EM-ICM algorithm to fit Turnbull estimator

Much more efficient - 19.7x speedup compared with the old algorithm!
commit 2880fe866d

src/turnbull.rs: 203 lines changed
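Background for readers new to the method (an editor's summary in the notation of the code below, not part of the commit message): the Turnbull NPMLE places probability mass p_j on the j-th Turnbull interval, S_j = 1 − Σ_{k<j} p_k is the survival probability at the start of interval j, and an observation censored to intervals l_i through r_i contributes likelihood S_{l_i} − S_{r_i+1}. The commit replaces the plain self-consistency (EM) iteration with alternating EM and ICM sweeps following Anderson-Bergman (2017), which the code cites. In outline:

```latex
% EM sweep (do_em_step): self-consistency update of the masses p_j
m_j = \frac{1}{n} \sum_{i \,:\, l_i \le j \le r_i} \frac{1}{S_{l_i} - S_{r_i+1}},
\qquad p_j \leftarrow p_j \, m_j

% ICM sweep (do_icm_step): diagonal Newton step on \Lambda_j = \ln(-\ln S_{j+1}),
% projected back to a monotone sequence by weighted isotonic regression (PAVA)
\Lambda \leftarrow \operatorname{PAVA}\!\Bigl(\Lambda - \alpha \, g \oslash h;\;
w = -\tfrac{1}{2} h\Bigr),
\qquad S_{j+1} \leftarrow \exp\!\bigl(-e^{\Lambda_j}\bigr)
```

Here g and h are the gradient and Hessian diagonal computed in do_icm_step, ⊘ is elementwise division, and the step size α is halved until the log-likelihood increases.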
@@ -27,6 +27,7 @@ use prettytable::{Table, format, row};
 use rayon::prelude::*;
 use serde::{Serialize, Deserialize};
 
+use crate::pava::monotonic_regression_pava;
 use crate::term::UnconditionalTermLike;
 
 #[derive(Args)]
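The one new import is the weighted isotonic-regression routine used by the ICM step. Its implementation lives in src/pava.rs and is not part of this diff; purely as a reference for the reader, a minimal pool-adjacent-violators (PAVA) routine of the same flavour could look like the following (a generic sketch over plain slices, not the crate's actual code, which operates on nalgebra DVectors):

```rust
/// Weighted isotonic regression by pool-adjacent-violators (PAVA).
/// Returns the non-decreasing fit minimising sum_j w[j] * (fit[j] - y[j])^2.
/// A generic illustration only - not the implementation in src/pava.rs.
fn pava_sketch(y: &[f64], w: &[f64]) -> Vec<f64> {
    // Each block holds (weighted mean, total weight, number of points pooled).
    let mut blocks: Vec<(f64, f64, usize)> = Vec::with_capacity(y.len());
    for (&y_j, &w_j) in y.iter().zip(w) {
        blocks.push((y_j, w_j, 1));
        // Pool backwards while adjacent blocks violate monotonicity.
        while blocks.len() > 1 && blocks[blocks.len() - 2].0 > blocks[blocks.len() - 1].0 {
            let (m2, w2, n2) = blocks.pop().unwrap();
            let (m1, w1, n1) = blocks.pop().unwrap();
            let w_total = w1 + w2;
            blocks.push(((m1 * w1 + m2 * w2) / w_total, w_total, n1 + n2));
        }
    }
    // Expand the pooled blocks back to a full-length fitted sequence.
    blocks
        .into_iter()
        .flat_map(|(mean, _, n)| std::iter::repeat(mean).take(n))
        .collect()
}
```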
@@ -205,7 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
     progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
     progress_bar.set_length(u64::MAX);
     progress_bar.reset();
-    progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
+    progress_bar.println("Running EM-ICM algorithm to fit Turnbull estimator");
 
     let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s);
 
@@ -285,24 +286,38 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
     return intervals;
 }
 
-fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut s: Vec<f64>) -> (Vec<f64>, f64) {
-    // Get likelihood for each observation (denominator of μ_ij)
-    let mut likelihood_obs = get_likelihood_obs(data, &s);
+fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>) -> (Vec<f64>, f64) {
+    // Pre-compute S, the survival probability at the start of each interval
+    let mut s = p_to_s(&p);
+
+    // Get likelihood for each observation
+    let likelihood_obs = get_likelihood_obs(data, &s);
     let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum();
 
     let mut iteration = 1;
     loop {
-        // Compute π_j to update s
-        let pi = compute_pi(data, &s, likelihood_obs);
+        // -------
+        // EM step
 
-        let likelihood_obs_new = get_likelihood_obs(data, &pi);
-        let ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
+        let p_after_em = do_em_step(data, &p, &s);
+        let s_after_em = p_to_s(&p_after_em);
 
+        let likelihood_obs_after_em = get_likelihood_obs(data, &s_after_em);
+        let ll_model_after_em: f64 = likelihood_obs_after_em.iter().map(|l| l.ln()).sum();
+
+        p = p_after_em;
+        s = s_after_em;
+
+        // --------
+        // ICM step
+
+        let (p_new, s_new, ll_model_new) = do_icm_step(data, &p, &s, ll_model_after_em);
+
         let ll_change = ll_model_new - ll_model;
         let converged = ll_change <= ll_tolerance;
 
-        s = pi;
-        likelihood_obs = likelihood_obs_new;
+        p = p_new;
+        s = s_new;
         ll_model = ll_model_new;
 
         // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
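Both sweeps are scored against the same objective, the observed-data log-likelihood that get_likelihood_obs evaluates (restated here; the hunk below shows get_likelihood_obs switching from summing the masses over s[l..=r] to differencing the survival probabilities):

```latex
\ell(p) = \sum_{i=1}^{n} \ln\bigl(S_{l_i} - S_{r_i+1}\bigr),
\qquad S_j = 1 - \sum_{k < j} p_k
```

One iteration of the loop is a full EM sweep followed by an ICM sweep, and the loop exits once the combined improvement in ℓ is at most ll_tolerance (the converged flag above); max_iterations presumably caps the sweep count in the lines elided between hunks.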
@@ -324,48 +339,150 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
         }
     }
 
-    return (s, ll_model);
+    return (p, ll_model);
+}
+
+fn p_to_s(p: &Vec<f64>) -> Vec<f64> {
+    let mut s = Vec::with_capacity(p.len() + 1); // Survival probabilities
+    let mut survival = 1.0;
+    s.push(1.0);
+    for p_j in p.iter() {
+        survival -= p_j;
+        s.push(survival);
+    }
+    return s;
+}
+
+fn s_to_lambda(s: &Vec<f64>) -> Vec<f64> {
+    // S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these
+    let mut lambda = Vec::with_capacity(s.len() - 2); // Cumulative hazard
+    for s_j in &s[1..(s.len() - 1)] {
+        lambda.push((-s_j.ln()).ln());
+    }
+    return lambda;
 }
 
 fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
     return data.data_time_interval_indexes
         .par_iter()
-        .map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum())
-        .collect();
+        .map(|(idx_left, idx_right)| s[*idx_left] - s[*idx_right + 1])
+        .collect(); // TODO: Return iterator directly
 }
 
-fn compute_pi(data: &TurnbullData, s: &Vec<f64>, likelihood_obs: Vec<f64>) -> Vec<f64> {
-    /*
-    let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
-    for ((idx_left, idx_right), likelihood_obs_i) in data.data_time_interval_indexes.iter().zip(likelihood_obs.iter()) {
-        for j in *idx_left..(*idx_right + 1) {
-            pi[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
-        }
-    }
-    */
+fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
+    // Compute contributions to m
+    let mut m_contrib = vec![0.0; data.num_intervals()];
+    for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
+        let contrib = 1.0 / (s[*idx_left] - s[*idx_right + 1]);
 
-    let pi = data.data_time_interval_indexes.par_iter().zip(likelihood_obs.par_iter())
-        .fold_with(
-            // Compute the contributions to pi[j] for each observation and sum them in parallel using fold_with
-            vec![0.0; data.num_intervals()],
-            |mut acc, ((idx_left, idx_right), likelihood_obs_i)| {
-                // Contributions to pi[j] for the i-th observation
-                for j in *idx_left..(*idx_right + 1) {
-                    acc[j] += s[j] / likelihood_obs_i / data.num_obs() as f64;
-                }
-                acc
-            }
-        )
-        .reduce(
-            // Reduce all the sub-sums from fold_with into the total sum
-            || vec![0.0; data.num_intervals()],
-            |mut acc, subsum| {
-                acc.iter_mut().zip(subsum.iter()).for_each(|(x, y)| *x += y);
-                acc
-            }
-        );
+        // Adds to m for the first interval in the observation
+        m_contrib[*idx_left] += contrib;
 
-    return pi;
+        // Subtracts from m for the first interval beyond the observation
+        if *idx_right + 1 < data.num_intervals() {
+            m_contrib[*idx_right + 1] -= contrib;
+        }
+    }
+
+    // Compute m
+    let mut m = Vec::with_capacity(data.num_intervals());
+    let mut m_last = 0.0;
+    for m_contrib_j in m_contrib {
+        let m_next = m_last + m_contrib_j / (data.num_obs() as f64);
+        m.push(m_next);
+        m_last = m_next;
+    }
+
+    // Update p
+    // p := p * m
+    let p_new = p.par_iter().zip(m.into_par_iter()).map(|(p_j, m_j)| p_j * m_j).collect();
+
+    return p_new;
+}
+
+fn do_icm_step(data: &TurnbullData, _p: &Vec<f64>, s: &Vec<f64>, ll_model: f64) -> (Vec<f64>, Vec<f64>, f64) {
+    // Compute Λ, the cumulative hazard
+    // Since Λ = -inf when survival is 1, and Λ = inf when survival is 0, these are omitted
+    // The entry at lambda[j] corresponds to the survival immediately before time point j + 1
+    let lambda = s_to_lambda(&s);
+
+    // Compute gradient and diagonal of Hessian
+    let mut gradient = vec![0.0; data.num_intervals() - 1];
+    let mut hessdiag = vec![0.0; data.num_intervals() - 1];
+    for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
+        let denom = s[*idx_left] - s[*idx_right + 1];
+
+        // Add to gradient[j] when j + 1 == idx_right + 1
+        // Add to hessdiag[j] when j + 1 == idx_right + 1
+        if *idx_right < gradient.len() {
+            let j = *idx_right;
+            gradient[j] += (-lambda[j].exp() + lambda[j]).exp() / denom;
+
+            let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
+            let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
+            hessdiag[j] += a - b;
+        }
+
+        // Subtract from gradient[j] when j + 1 == idx_left
+        // Add to hessdiag[j] when j + 1 == idx_left
+        if *idx_left > 0 {
+            let j = *idx_left - 1;
+            gradient[j] -= (-lambda[j].exp() + lambda[j]).exp() / denom;
+
+            let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
+            let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
+            hessdiag[j] += -a - b;
+        }
+    }
+
+    // Description in Anderson-Bergman (2017) is slightly misleading
+    // Since we are maximising the likelihood, the second derivatives will be negative
+    // And we will move in the direction of the gradient
+    // So there are a few more negative signs here than suggested
+
+    let weights = -DVector::from_vec(hessdiag.clone()) / 2.0;
+    let gradient_over_hessdiag = DVector::from_vec(gradient.par_iter().zip(hessdiag.par_iter()).map(|(g, h)| g / h).collect());
+
+    let mut s_new;
+    let mut p_new;
+    let mut ll_model_new: f64;
+
+    // Take as large a step as possible while the log-likelihood increases
+    let mut step_size_exponent: i32 = 0;
+    loop {
+        let step_size = 0.5_f64.powi(step_size_exponent);
+        let lambda_target = -gradient_over_hessdiag.clone() * step_size + DVector::from_vec(lambda.clone());
+
+        let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());
+
+        // Convert Λ to S to p
+        s_new = Vec::with_capacity(data.num_intervals() + 1);
+        p_new = Vec::with_capacity(data.num_intervals());
+
+        let mut survival = 1.0;
+        s_new.push(1.0);
+        for lambda_j in lambda_new.iter() {
+            let next_survival = (-lambda_j.exp()).exp();
+            s_new.push(next_survival);
+            p_new.push(survival - next_survival);
+            survival = next_survival;
+        }
+        s_new.push(0.0);
+        p_new.push(survival);
+
+        let likelihood_obs_new = get_likelihood_obs(data, &s_new);
+        ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
+
+        if ll_model_new > ll_model {
+            return (p_new, s_new, ll_model_new);
+        }
+
+        step_size_exponent += 1;
+
+        if step_size_exponent > 10 {
+            panic!("ICM fails to increase log-likelihood");
+        }
+    }
 }
 
 fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
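Two implementation notes on the new functions above. First, much of the speedup comes from do_em_step replacing compute_pi's per-observation scan over every covered interval with a difference array: each observation touches only the two endpoints of its interval range, and a single prefix sum then recovers m_j = (1/n) Σ_{i: l_i ≤ j ≤ r_i} 1/(S_{l_i} − S_{r_i+1}), in O(n + J) rather than O(n · J). A self-contained toy illustration of the same trick (made-up numbers, not from the commit):

```rust
// Difference-array sketch of the accumulation in do_em_step.
// Hypothetical toy data: three observations whose censoring intervals
// cover Turnbull intervals 0..=1, 1..=2 and 0..=2 respectively.
fn main() {
    let num_intervals = 3;
    let obs = [(0usize, 1usize), (1, 2), (0, 2)]; // (idx_left, idx_right)
    let denoms = [0.6, 0.5, 1.0]; // S[l_i] - S[r_i + 1] for each observation

    // Record each observation's contribution only at the two endpoints...
    let mut m_contrib = vec![0.0; num_intervals];
    for (&(l, r), &denom) in obs.iter().zip(&denoms) {
        m_contrib[l] += 1.0 / denom;
        if r + 1 < num_intervals {
            m_contrib[r + 1] -= 1.0 / denom;
        }
    }

    // ...then a prefix sum spreads each contribution across every
    // interval the observation covers.
    let n = obs.len() as f64;
    let mut m = Vec::with_capacity(num_intervals);
    let mut acc = 0.0;
    for c in m_contrib {
        acc += c / n;
        m.push(acc);
    }
    println!("{:?}", m); // m[j] = (1/n) * sum of 1/denom over observations covering j
}
```

Second, do_icm_step's inner loop is a backtracking line search: starting from the full step (step_size_exponent = 0), the step is halved until the candidate raises the log-likelihood, and the function gives up (panics) once the exponent exceeds 10. Together with the EM sweep, which by construction cannot decrease the likelihood, this keeps ℓ monotonically non-decreasing across iterations, which is what the ll_change convergence test relies on.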