turnbull: Introduce analytical solution to computing Hessian
Makes runtime of computing Hessian negligible!
This commit is contained in:
parent
0a8c77fa2c
commit
f043f7c67d
@ -21,7 +21,7 @@ use std::io;
|
|||||||
|
|
||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use csv::{Reader, StringRecord};
|
use csv::{Reader, StringRecord};
|
||||||
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressIterator, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
|
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
|
||||||
use prettytable::{Table, format, row};
|
use prettytable::{Table, format, row};
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
@ -217,12 +217,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
// --------------------------------------------------
|
// --------------------------------------------------
|
||||||
// Compute standard errors for survival probabilities
|
// Compute standard errors for survival probabilities
|
||||||
|
|
||||||
progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Compute Hessian {pos}/{len}").unwrap());
|
let hessian = compute_hessian(&data, &s);
|
||||||
progress_bar.set_length(data.num_obs() as u64);
|
|
||||||
progress_bar.reset();
|
|
||||||
progress_bar.println("Computing standard errors for survival probabilities");
|
|
||||||
|
|
||||||
let hessian = compute_hessian(&data, progress_bar.clone(), &s);
|
|
||||||
|
|
||||||
let mut survival_prob_se: DVector<f64>;
|
let mut survival_prob_se: DVector<f64>;
|
||||||
|
|
||||||
@ -333,33 +328,50 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
|
|||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_hessian(data: &TurnbullData, progress_bar: ProgressBar, s: &DVector<f64>) -> DMatrix<f64> {
|
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> {
|
||||||
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
||||||
|
|
||||||
for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
|
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
|
||||||
let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
|
// Compute 1 / (Σ_j α_{i,j} s_j)
|
||||||
hessian_denominator = hessian_denominator.powi(2);
|
let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
|
||||||
|
one_over_hessian_denominator = one_over_hessian_denominator.powi(-2);
|
||||||
|
|
||||||
let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
|
// The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})
|
||||||
let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian
|
// This is nonzero only when α_{i,h} ≠ α_{i,h+1} AND α_{i,k} ≠ α_{i,k+1}
|
||||||
|
// Since each observation spans a continuous sequence of intervals, this is true only at two each of h and k at the boundaries of the observation
|
||||||
|
// h = last interval not involving the observation, h + 1 = first interval involving the observation, etc.
|
||||||
|
|
||||||
for h in idx_start..idx_end {
|
// if *idx_left > 0 { h1 = idx_left - 1; }
|
||||||
let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
|
// if *idx_right < data.num_intervals() - 1 { h2 = *idx_right; }
|
||||||
let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
|
|
||||||
|
|
||||||
hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
|
if *idx_left > 0 {
|
||||||
|
let h1 = idx_left - 1;
|
||||||
|
|
||||||
for k in idx_start..h {
|
// (h, k) = (h1, h1)
|
||||||
let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
|
// numerator is -(0 - 1)(0 - 1) = -1
|
||||||
let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
|
hessian[(h1, h1)] -= one_over_hessian_denominator;
|
||||||
|
}
|
||||||
|
|
||||||
let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
|
if *idx_right < data.num_intervals() - 1 {
|
||||||
hessian[(h, k)] -= value;
|
let h2 = *idx_right;
|
||||||
hessian[(k, h)] -= value;
|
|
||||||
|
// (h, k) = (h2, h2)
|
||||||
|
// numerator is -(1 - 0)(1 - 0) = -1
|
||||||
|
hessian[(h2, h2)] -= one_over_hessian_denominator;
|
||||||
|
|
||||||
|
if *idx_left > 0 {
|
||||||
|
let h1 = idx_left - 1;
|
||||||
|
|
||||||
|
// (h, k) = (h1, h2)
|
||||||
|
// numerator is -(0 - 1)(1 - 0) = 1
|
||||||
|
hessian[(h1, h2)] += one_over_hessian_denominator;
|
||||||
|
|
||||||
|
// (h, k) = (h2, h1)
|
||||||
|
// numerator is -(1 - 0)(0 - 1) = 1
|
||||||
|
hessian[(h2, h1)] += one_over_hessian_denominator;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
progress_bar.finish();
|
|
||||||
|
|
||||||
return hessian;
|
return hessian;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user