turnbull: Pre-compute survival probabilities

This commit is contained in:
RunasSudo 2023-10-28 23:08:03 +11:00
parent 250cfd8798
commit 81b0b3f9b5
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -287,8 +287,11 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
} }
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>) -> (Vec<f64>, f64) { fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>) -> (Vec<f64>, f64) {
// Pre-compute S, the survival probability at the start of each interval
let mut s = p_to_s(&p);
// Get likelihood for each observation // Get likelihood for each observation
let mut likelihood_obs = get_likelihood_obs(data, &p); let mut likelihood_obs = get_likelihood_obs(data, &s);
let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum(); let mut ll_model: f64 = likelihood_obs.iter().map(|l| l.ln()).sum();
let mut iteration = 1; let mut iteration = 1;
@ -296,55 +299,41 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
// ------- // -------
// EM step // EM step
// Pre-compute S, the survival probability at the start of each interval
let mut s = Vec::with_capacity(data.num_intervals() + 1);
let mut survival = 1.0;
s.push(1.0);
for p_j in p.iter() {
survival -= p_j;
s.push(survival);
}
// Update p // Update p
let mut p_new = Vec::with_capacity(data.num_intervals()); let mut p_new = Vec::with_capacity(data.num_intervals());
for j in 0..data.num_intervals() { for j in 0..data.num_intervals() {
let tmp: f64 = data.data_time_interval_indexes.iter() let tmp: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right) .filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right)
//.map(|(idx_left, idx_right)| 1.0 / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>())
.map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1])) .map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1]))
.sum(); .sum();
p_new.push(p[j] * tmp / (data.num_obs() as f64)); p_new.push(p[j] * tmp / (data.num_obs() as f64));
} }
let likelihood_obs_after_em = get_likelihood_obs(data, &p_new); let mut s_new = p_to_s(&p_new);
let likelihood_obs_after_em = get_likelihood_obs(data, &s_new);
let ll_model_after_em: f64 = likelihood_obs_after_em.iter().map(|l| l.ln()).sum(); let ll_model_after_em: f64 = likelihood_obs_after_em.iter().map(|l| l.ln()).sum();
p = p_new; p = p_new;
s = s_new;
// -------- // --------
// ICM step // ICM step
// Compute Λ // Compute Λ, the cumulative hazard
// S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these let lambda = s_to_lambda(&s);
let mut lambda = Vec::with_capacity(data.num_intervals() - 1);
let mut survival: f64 = 1.0;
for j in 0..(data.num_intervals() - 1) {
survival -= p[j];
lambda.push((-survival.ln()).ln());
}
// Compute gradient // Compute gradient
let mut gradient = DVector::zeros(data.num_intervals() - 1); let mut gradient = DVector::zeros(data.num_intervals() - 1);
for j in 0..(data.num_intervals() - 1) { for j in 0..(data.num_intervals() - 1) {
let sum_right: f64 = data.data_time_interval_indexes.iter() let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1) .filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>()) .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
.sum(); .sum();
let sum_left: f64 = data.data_time_interval_indexes.iter() let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left) .filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / p[*idx_left..(*idx_right + 1)].iter().sum::<f64>()) .map(|(idx_left, idx_right)| (-lambda[j].exp() + lambda[j]).exp() / (s[*idx_left] - s[*idx_right + 1]))
.sum(); .sum();
gradient[j] = sum_right - sum_left; gradient[j] = sum_right - sum_left;
@ -356,7 +345,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
let sum_left: f64 = data.data_time_interval_indexes.iter() let sum_left: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_left) .filter(|(idx_left, idx_right)| j + 1 == *idx_left)
.map(|(idx_left, idx_right)| { .map(|(idx_left, idx_right)| {
let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum(); let denom = s[*idx_left] - s[*idx_right + 1];
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom; let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2); let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
-a - b -a - b
@ -366,7 +355,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
let sum_right: f64 = data.data_time_interval_indexes.iter() let sum_right: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1) .filter(|(idx_left, idx_right)| j + 1 == *idx_right + 1)
.map(|(idx_left, idx_right)| { .map(|(idx_left, idx_right)| {
let denom: f64 = p[*idx_left..(*idx_right + 1)].iter().sum(); let denom = s[*idx_left] - s[*idx_right + 1];
let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom; let a = ((lambda[j] - lambda[j].exp()).exp() * (1.0 - lambda[j].exp())) / denom;
let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2); let b = (2.0 * lambda[j] - 2.0 * lambda[j].exp()).exp() / denom.powi(2);
a - b a - b
@ -395,17 +384,21 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
let lambda_new = monotonic_regression_pava(lambda_target, weights.clone()); let lambda_new = monotonic_regression_pava(lambda_target, weights.clone());
// Convert Λ to S to p // Convert Λ to S to p
s_new = Vec::with_capacity(data.num_intervals() + 1);
p_new = Vec::with_capacity(data.num_intervals()); p_new = Vec::with_capacity(data.num_intervals());
let mut survival = 1.0; let mut survival = 1.0;
s_new.push(1.0);
for lambda_j in lambda_new.iter() { for lambda_j in lambda_new.iter() {
let next_survival = (-lambda_j.exp()).exp(); let next_survival = (-lambda_j.exp()).exp();
s_new.push(next_survival);
p_new.push(survival - next_survival); p_new.push(survival - next_survival);
survival = next_survival; survival = next_survival;
} }
s_new.push(0.0);
p_new.push(survival); p_new.push(survival);
let likelihood_obs_new = get_likelihood_obs(data, &p_new); let likelihood_obs_new = get_likelihood_obs(data, &s_new);
ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum(); ll_model_new = likelihood_obs_new.iter().map(|l| l.ln()).sum();
if ll_model_new > ll_model_after_em { if ll_model_new > ll_model_after_em {
@ -423,7 +416,7 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
let converged = ll_change <= ll_tolerance; let converged = ll_change <= ll_tolerance;
p = p_new; p = p_new;
//likelihood_obs = likelihood_obs_new; s = s_new;
ll_model = ll_model_new; ll_model = ll_model_new;
// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
@ -448,11 +441,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
return (p, ll_model); return (p, ll_model);
} }
/// Converts the per-interval probabilities `p` into survival probabilities.
///
/// The result has `p.len() + 1` entries: the first is `1.0`, and every later
/// entry is the previous entry minus the corresponding element of `p`, i.e.
/// entry `j + 1` is the running value `1 - p[0] - ... - p[j]`.
///
/// Accepts a slice; callers that pass `&Vec<f64>` keep working through deref
/// coercion.
fn p_to_s(p: &[f64]) -> Vec<f64> {
    let mut s = Vec::with_capacity(p.len() + 1); // Survival probabilities
    s.push(1.0);
    // Subtract one element at a time (rather than subtracting a prefix sum) so
    // the floating-point results are those of a plain running subtraction.
    s.extend(p.iter().scan(1.0_f64, |survival, p_j| {
        *survival -= p_j;
        Some(*survival)
    }));
    s
}
/// Converts survival probabilities `s` into the Λ values (labelled
/// "cumulative hazard" in this file), computed as `ln(-ln(s_j))`.
///
/// The first and last entries of `s` are skipped, so the result has
/// `s.len() - 2` entries. If `s` has fewer than two entries there is nothing
/// between the endpoints and an empty vector is returned instead of panicking
/// on the `usize` underflow / out-of-range slice.
///
/// Accepts a slice; callers that pass `&Vec<f64>` keep working through deref
/// coercion.
fn s_to_lambda(s: &[f64]) -> Vec<f64> {
    // S = 1 means Λ = -inf and S = 0 means Λ = inf so skip these
    let interior: &[f64] = match s {
        [_, interior @ .., _] => interior,
        // Zero or one element: no interior entries exist.
        _ => &[],
    };
    interior.iter().map(|s_j| (-s_j.ln()).ln()).collect()
}
// NOTE(review): this looks like a side-by-side diff rendering, with each line
// holding the pre-change text followed by the post-change text - confirm
// against the real file before editing.
// Right-hand text: for every (idx_left, idx_right) pair in
// data.data_time_interval_indexes, yields s[idx_left] - s[idx_right + 1] and
// collects the values into a Vec<f64>.
// Left-hand text: sums the elements s[idx_left..(idx_right + 1)] instead.
fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> { fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
return data.data_time_interval_indexes return data.data_time_interval_indexes
.par_iter() .par_iter()
.map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum()) .map(|(idx_left, idx_right)| s[*idx_left] - s[*idx_right + 1])
.collect(); .collect(); // TODO: Return iterator directly
} }
fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> { fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {