diff --git a/src/turnbull.rs b/src/turnbull.rs index ac60a88..936df85 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -21,7 +21,7 @@ use std::io; use clap::{Args, ValueEnum}; use csv::{Reader, StringRecord}; -use indicatif::{ProgressBar, ProgressDrawTarget, ProgressIterator, ProgressStyle}; +use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2}; use prettytable::{Table, format, row}; use serde::{Serialize, Deserialize}; @@ -217,12 +217,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i // -------------------------------------------------- // Compute standard errors for survival probabilities - progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Compute Hessian {pos}/{len}").unwrap()); - progress_bar.set_length(data.num_obs() as u64); - progress_bar.reset(); - progress_bar.println("Computing standard errors for survival probabilities"); - - let hessian = compute_hessian(&data, progress_bar.clone(), &s); + let hessian = compute_hessian(&data, &s); let mut survival_prob_se: DVector; @@ -333,33 +328,50 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma return s; } -fn compute_hessian(data: &TurnbullData, progress_bar: ProgressBar, s: &DVector) -> DMatrix { +fn compute_hessian(data: &TurnbullData, s: &DVector) -> DMatrix { let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); - for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) { - let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); - hessian_denominator = hessian_denominator.powi(2); + for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { + // Compute 1 / (Σ_j α_{i,j} s_j) + let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); + one_over_hessian_denominator = one_over_hessian_denominator.powi(-2); - let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case - let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian + // The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1}) + // This is nonzero only when α_{i,h} ≠ α_{i,h+1} AND α_{i,k} ≠ α_{i,k+1} + // Since each observation spans a continuous sequence of intervals, this is true only at two each of h and k at the boundaries of the observation + // h = last interval not involving the observation, h + 1 = first interval involving the observation, etc. - for h in idx_start..idx_end { - let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 }; - let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 }; + // if *idx_left > 0 { h1 = idx_left - 1; } + // if *idx_right < data.num_intervals() - 1 { h2 = *idx_right; } + + if *idx_left > 0 { + let h1 = idx_left - 1; - hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator; + // (h, k) = (h1, h1) + // numerator is -(0 - 1)(0 - 1) = -1 + hessian[(h1, h1)] -= one_over_hessian_denominator; + } + + if *idx_right < data.num_intervals() - 1 { + let h2 = *idx_right; - for k in idx_start..h { - let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 }; - let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 }; + // (h, k) = (h2, h2) + // numerator is -(1 - 0)(1 - 0) = -1 + hessian[(h2, h2)] -= one_over_hessian_denominator; + + if *idx_left > 0 { + let h1 = idx_left - 1; - let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator; - hessian[(h, k)] -= value; - hessian[(k, h)] -= value; + // (h, k) = (h1, h2) + // numerator is -(0 - 1)(1 - 0) = 1 + hessian[(h1, h2)] += one_over_hessian_denominator; + + // (h, k) = (h2, h1) + // numerator is -(1 - 0)(0 - 1) = 1 + hessian[(h2, h1)] += one_over_hessian_denominator; } } } - progress_bar.finish(); return hessian; }