turnbull: Introduce analytical solution to computing Hessian

Makes runtime of computing Hessian negligible!
This commit is contained in:
RunasSudo 2023-10-20 20:47:49 +11:00
parent 0a8c77fa2c
commit f043f7c67d
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 36 additions and 24 deletions

View File

@ -21,7 +21,7 @@ use std::io;
use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressIterator, ProgressStyle};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
use prettytable::{Table, format, row};
use serde::{Serialize, Deserialize};
@ -217,12 +217,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
// --------------------------------------------------
// Compute standard errors for survival probabilities
progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Compute Hessian {pos}/{len}").unwrap());
progress_bar.set_length(data.num_obs() as u64);
progress_bar.reset();
progress_bar.println("Computing standard errors for survival probabilities");
let hessian = compute_hessian(&data, progress_bar.clone(), &s);
let hessian = compute_hessian(&data, &s);
let mut survival_prob_se: DVector<f64>;
@ -333,33 +328,50 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
return s;
}
fn compute_hessian(data: &TurnbullData, progress_bar: ProgressBar, s: &DVector<f64>) -> DMatrix<f64> {
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> {
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
hessian_denominator = hessian_denominator.powi(2);
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
// Compute 1 / (Σ_j α_{i,j} s_j)
let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
one_over_hessian_denominator = one_over_hessian_denominator.powi(-2);
let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian
// The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})
// This is nonzero only when α_{i,h} ≠ α_{i,h+1} AND α_{i,k} ≠ α_{i,k+1}
// Since each observation spans a continuous sequence of intervals, this is true only at two each of h and k at the boundaries of the observation
// h = last interval not involving the observation, h + 1 = first interval involving the observation, etc.
for h in idx_start..idx_end {
let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
// if *idx_left > 0 { h1 = idx_left - 1; }
// if *idx_right < data.num_intervals() - 1 { h2 = *idx_right; }
if *idx_left > 0 {
let h1 = idx_left - 1;
hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
// (h, k) = (h1, h1)
// numerator is -(0 - 1)(0 - 1) = -1
hessian[(h1, h1)] -= one_over_hessian_denominator;
}
if *idx_right < data.num_intervals() - 1 {
let h2 = *idx_right;
for k in idx_start..h {
let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
// (h, k) = (h2, h2)
// numerator is -(1 - 0)(1 - 0) = -1
hessian[(h2, h2)] -= one_over_hessian_denominator;
if *idx_left > 0 {
let h1 = idx_left - 1;
let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
hessian[(h, k)] -= value;
hessian[(k, h)] -= value;
// (h, k) = (h1, h2)
// numerator is -(0 - 1)(1 - 0) = 1
hessian[(h1, h2)] += one_over_hessian_denominator;
// (h, k) = (h2, h1)
// numerator is -(1 - 0)(0 - 1) = 1
hessian[(h2, h1)] += one_over_hessian_denominator;
}
}
}
progress_bar.finish();
return hessian;
}