// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
const Z_97_5: f64 = 1.959964; // 97.5th percentile of the standard normal distribution, for 95% confidence intervals

use std::io;

use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{DMatrix, DVector, Matrix1xX};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};

use crate::term::UnconditionalTermLike;

#[derive(Args)]
pub struct IntCoxArgs {
    /// Path to CSV input file containing the observations
    #[arg()]
    input: String,
    
    /// Output format
    #[arg(long, value_enum, default_value="text")]
    output: OutputFormat,
    
    /// Maximum number of E-M iterations to attempt
    #[arg(long, default_value="100")]
    max_iterations: u32,
    
    /// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance
    #[arg(long, default_value="0.001")]
    tolerance: f64,
    
    /// Estimate baseline hazard function using Turnbull innermost intervals
    #[arg(long)]
    reduced: bool,
}
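
// Example invocation (illustrative only: assumes this struct is exposed as an `intcox`
// subcommand of the `hpstat` binary, and that observations.csv exists):
//
//     hpstat intcox observations.csv --output json --max-iterations 200 --tolerance 0.0001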

#[derive(ValueEnum, Clone)]
enum OutputFormat {
    Text,
    Json
}

pub fn main(args: IntCoxArgs) {
    // Read data
    let (indep_names, data_times, data_indep) = read_data(&args.input);
    
    // Fit regression
    let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
    let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.tolerance, args.reduced, progress_bar);
    
    // Display output
    match args.output {
        OutputFormat::Text => {
            println!();
            println!();
            println!("LL-Model = {:.5}", result.ll_model);
            println!("LL-Null = {:.5}", result.ll_null);
            
            let mut summary = Table::new();
            let format = format::FormatBuilder::new()
                .separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
                .padding(2, 2)
                .build();
            summary.set_format(format);
            summary.set_titles(row!["Parameter", c->"β", c->"Std Err.", c->"exp(β)", H2c->"(95% CI)"]);
            for (i, indep_name) in indep_names.iter().enumerate() {
                summary.add_row(row![
                    indep_name,
                    r->format!("{:.5}", result.params[i]),
                    r->format!("{:.5}", result.params_se[i]),
                    r->format!("{:.5}", result.params[i].exp()),
                    r->format!("({:.5},", (result.params[i] - Z_97_5 * result.params_se[i]).exp()),
                    format!("{:.5})", (result.params[i] + Z_97_5 * result.params_se[i]).exp()),
                ]);
            }
            summary.printstd();
        }
        OutputFormat::Json => {
            println!("{}", serde_json::to_string(&result).unwrap());
        }
    }
}
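
/// Read interval-censored observations from the CSV file at `path` (or stdin if `path` is "-")
///
/// The first two columns give the left and right endpoints of each observation's censoring
/// interval (the literal value "inf" denotes an infinite right endpoint); every remaining
/// column is treated as a covariate. A hypothetical input with illustrative column names:
///
/// ```text
/// left,right,age,treatment
/// 0,5,40,1
/// 3,inf,62,0
/// ```
///
/// Returns the covariate names, the 2 × n matrix of (left, right) times, and the
/// (covariates × n) matrix of covariate values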
pub fn read_data(path: &str) -> (Vec<String>, DMatrix<f64>, DMatrix<f64>) {
    // Read CSV into memory
    let headers: StringRecord;
    let records: Vec<StringRecord>;
    if path == "-" {
        let mut csv_reader = Reader::from_reader(io::stdin());
        headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    } else {
        let mut csv_reader = Reader::from_path(path).unwrap();
        headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    }
    
    // Read data into matrices
    let mut data_times: DMatrix<f64> = DMatrix::zeros(
        2, // Left time, right time
        records.len()
    );
    
    // Called "Z" in the paper and "X" in the C++ code
    let mut data_indep: DMatrix<f64> = DMatrix::zeros(
        headers.len() - 2,
        records.len()
    );
    
    // Parse header row
    let indep_names: Vec<String> = headers.iter().skip(2).map(String::from).collect();
    
    // Parse data
    for (i, row) in records.iter().enumerate() {
        for (j, item) in row.iter().enumerate() {
            let value = match item {
                "inf" => f64::INFINITY,
                _ => item.parse().expect("Malformed float")
            };
            
            if j < 2 {
                data_times[(j, i)] = value;
            } else {
                data_indep[(j - 2, i)] = value;
            }
        }
    }
    
    return (indep_names, data_times, data_indep);
}

struct IntervalCensoredCoxData {
    data_times: DMatrix<f64>,
    data_indep: DMatrix<f64>,
    
    // Cached intermediate values
    time_points: Vec<f64>,
    r_star_indicator: DMatrix<f64>,
    z_z_transpose: Vec<DMatrix<f64>>,
}

impl IntervalCensoredCoxData {
    fn num_obs(&self) -> usize {
        return self.data_indep.ncols();
    }
    
    fn num_covs(&self) -> usize {
        return self.data_indep.nrows();
    }
    
    fn num_times(&self) -> usize {
        return self.time_points.len();
    }
}
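
/// Fit an interval-censored Cox proportional hazards model using an E-M algorithm
///
/// Covariates are standardised internally and the estimates unstandardised before being
/// returned. The E-M loop alternates between computing posterior weights (E-step) and
/// updating β and the baseline hazard increments λ (M-step) until the maximum absolute
/// change in the estimates is below `tolerance`; standard errors are then obtained by
/// numerically differencing profile log-likelihoods.
///
/// A minimal usage sketch (hypothetical file name), mirroring `main` above:
///
/// ```ignore
/// let (indep_names, data_times, data_indep) = read_data("observations.csv");
/// let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.001, false, ProgressBar::hidden());
/// println!("{:?} {:?}", indep_names, result.params);
/// ```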
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
    // ----------------------
    // Prepare for regression
    
    // Standardise values
    let indep_means = data_indep.column_mean();
    let indep_stdev = data_indep.column_variance().apply_into(|x| { *x = (*x * data_indep.ncols() as f64 / (data_indep.ncols() - 1) as f64).sqrt(); });
    for j in 0..data_indep.nrows() {
        data_indep.row_mut(j).apply(|x| *x = (*x - indep_means[j]) / indep_stdev[j]);
    }
    
    // Get time points (t_0 = 0, t_1, ..., t_m)
    let mut time_points: Vec<f64>;
    if reduced {
        // Turnbull intervals
        let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
        all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true)));
        all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false)));
        all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
        
        time_points = Vec::new();
        for i in 1..all_time_points.len() {
            if all_time_points[i - 1].1 == true && all_time_points[i].1 == false {
                time_points.push(all_time_points[i - 1].0);
                time_points.push(all_time_points[i].0);
            }
        }
        time_points.push(0.0); // Ensure 0 is in the list
        time_points.retain(|t| t.is_finite()); // Remove infinity
        time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
        time_points.dedup();
    } else {
        // All observed intervals
        time_points = data_times.iter().copied().collect();
        time_points.push(0.0); // Ensure 0 is in the list
        time_points.retain(|t| t.is_finite()); // Remove infinity
        time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
        time_points.dedup();
    }
    
    // Initialise β, λ
    let mut beta = DVector::zeros(data_indep.nrows());
    let mut lambda = DVector::repeat(time_points.len(), 1.0 / (time_points.len() - 1) as f64);
    
    // Compute I(t_k <= R*_i)
    // Where R*_i is R_i if R_i ≠ ∞, otherwise it is L_i
    let mut r_star_indicator = DMatrix::zeros(data_indep.ncols(), time_points.len());
    for (i, observation) in data_times.column_iter().enumerate() {
        let time_right_star = if observation[1].is_finite() { observation[1] } else { observation[0] };
        for (k, time) in time_points.iter().enumerate() {
            if *time <= time_right_star {
                // t_k <= R*_i
                r_star_indicator[(i, k)] = 1.0;
            } else {
                r_star_indicator[(i, k)] = 0.0;
            }
        }
    }
    
    // Pre-compute Z * Z^T
    // Indexed by observation -> Matrix (num-covariates, num-covariates)
    let mut z_z_transpose: Vec<DMatrix<f64>> = Vec::new();
    for i in 0..data_indep.ncols() {
        let covariates = data_indep.column(i);
        z_z_transpose.push(covariates * covariates.transpose());
    }
    
    let data = IntervalCensoredCoxData {
        data_times: data_times,
        data_indep: data_indep,
        time_points: time_points,
        r_star_indicator: r_star_indicator,
        z_z_transpose: z_z_transpose,
    };
    
    // -------------------
    // Apply E-M algorithm
    
    progress_bar.set_length(u64::MAX);
    progress_bar.reset();
    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
    progress_bar.println("Running E-M algorithm to fit interval-censored Cox model");
    
    let mut iteration: u32 = 0;
    loop {
        // Pre-compute exp(β^T * Z_ik)
        let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
        
        // Do E-step
        let posterior_weight = do_e_step(&data, &exp_beta_z, &lambda);
        
        // Do M-step
        let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight);
        
        // Check for convergence
        let (coef_change, converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, tolerance);
        beta = new_beta;
        lambda = new_lambda;
        
        // Update progress bar
        // Estimate progress according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations
        let progress1 = ((-coef_change.log10()).max(0.0) / -tolerance.log10() * u64::MAX as f64) as u64;
        let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
        progress_bar.set_position(progress_bar.position().max(progress1).max(progress2));
        progress_bar.set_message(format!("Iter {} (delta = {:.6})", iteration + 1, coef_change));
        
        if converged {
            progress_bar.finish();
            break;
        }
        
        iteration += 1;
        if iteration >= max_iterations {
            panic!("Exceeded --max-iterations");
        }
    }
    
    // Compute log-likelihood
    let ll_model = log_likelihood_obs(&data, &beta, &lambda).sum();
    
    // Unstandardise betas
    let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
    for (j, beta_value) in beta.iter().enumerate() {
        beta_unstandardised[j] = beta_value / indep_stdev[j];
    }
    
    // -------------------------
    // Compute covariance matrix
    
    // Compute profile log-likelihoods
    let h = 5.0 / (data.num_obs() as f64).sqrt(); // "a constant of order n^(-1/2)"
    
    progress_bar.set_length(data.num_covs() as u64 + 2);
    progress_bar.reset();
    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Profile LL {pos}/{len}").unwrap());
    progress_bar.println("Profiling log-likelihood to compute covariance matrix");
    
    // ll_null = log-likelihood for null model
    // pll_toggle_zero = log-likelihoods for each observation at final beta
    // pll_toggle_one = log-likelihoods for each observation at toggled beta
    let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, tolerance).sum();
    
    let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, tolerance);
    progress_bar.inc(1);
    
    let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
        let mut pll_beta = beta.clone();
        pll_beta[j] += h;
        profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, tolerance)
    })
        .progress_with(progress_bar.clone())
        .collect();
    progress_bar.finish();
    
    let mut pll_matrix: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
    for i in 0..data.num_obs() {
        let toggle_none_i = pll_toggle_zero[i];
        
        let mut ps_i: DVector<f64> = DVector::zeros(data.num_covs());
        for p in 0..data.num_covs() {
            ps_i[p] = (pll_toggle_one[p][i] - toggle_none_i) / h;
        }
        
        pll_matrix += ps_i.clone() * ps_i.transpose();
    }
    
    let vcov = pll_matrix.try_inverse().expect("Matrix not invertible");
    
    // Unstandardise SEs
    let beta_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
    let mut beta_se_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
    for (j, se) in beta_se.iter().enumerate() {
        beta_se_unstandardised[j] = se / indep_stdev[j];
    }
    
    return IntervalCensoredCoxResult {
        params: beta_unstandardised.data.as_vec().clone(),
        params_se: beta_se_unstandardised.data.as_vec().clone(),
        cumulative_hazard: cumulative_hazard(&lambda).data.as_vec().clone(),
        cumulative_hazard_times: data.time_points,
        ll_model: ll_model,
        ll_null: ll_null,
    };
}
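
/// E-step: compute the posterior weights W_ik for each observation i and time point t_k
///
/// With G(x) = x (proportional hazards), W_ik = λ_k exp(β^T Z_i) / (1 - exp(-(S_R,i - S_L,i)))
/// when L_i < t_k ≤ R_i < ∞, and 0 in all other cases (where the weight is unused)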
fn do_e_step(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, lambda: &DVector<f64>) -> DMatrix<f64> {
    // Compute S_L and S_R (S_i1 and S_i2 in the paper)
    let s_left = e_step_compute_s(data, &exp_beta_z, lambda, 0);
    let s_right = e_step_compute_s(data, &exp_beta_z, lambda, 1);
    
    // In the paper, consideration is given to G(x)
    // But in a proportional hazards model, G(x) = x
    // So we omit the details
    // As a consequence, the posterior ξ_i are always 1
    
    // Compute posterior weights (W_ik, "posterior mean" in C++)
    let mut posterior_weight: DMatrix<f64> = DMatrix::zeros(data.num_obs(), data.num_times());
    for (i, observation) in data.data_times.column_iter().enumerate() {
        let time_left = observation[0];
        let time_right = observation[1];
        
        for (k, time) in data.time_points.iter().enumerate() {
            if *time <= time_left {
                // t_k <= L_i
                posterior_weight[(i, k)] = 0.0;
            } else if *time <= time_right && time_right.is_finite() {
                // L_i < t_k <= R_i, with R_i < ∞
                // Assumes r = 0
                posterior_weight[(i, k)] = lambda[k] * exp_beta_z[i] / (1.0 - (s_left[i] - s_right[i]).exp());
            } else {
                // None of the above circumstances
                // C++ says the weight is unused in this case
                // Set this to a non-NaN value so we can still do elementwise vector multiplication for masking
                posterior_weight[(i, k)] = 0.0;
            }
        }
    }
    
    return posterior_weight;
}
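
/// Compute S_i = Σ_{t_k ≤ T_i} λ_k exp(β^T Z_i), where T_i is the left (time_index = 0) or
/// right (time_index = 1) endpoint of observation i's interval; S_i = ∞ if that endpoint is ∞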
fn e_step_compute_s(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, lambda: &DVector<f64>, time_index: usize) -> DVector<f64> {
    let mut s: DVector<f64> = DVector::zeros(data.num_obs());
    for (i, observation) in data.data_times.column_iter().enumerate() {
        let time_cutoff = observation[time_index];
        if time_cutoff.is_infinite() {
            s[i] = f64::INFINITY;
        } else {
            for (k, time) in data.time_points.iter().enumerate() {
                if *time <= time_cutoff {
                    // time is t_k <= L_i, or t_k <= R_i, as applicable
                    s[i] += lambda[k] * exp_beta_z[i]; // Row 0, because all covariates are time-independent
                } else {
                    break;
                }
            }
        }
    }
    return s;
}
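
/// M-step: given the posterior weights from the E-step, compute an updated β (via the
/// Newton-type step below, using the S0/S1/S2 sums and the Σ matrix) and updated baseline
/// hazard increments λ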
fn do_m_step(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, beta: &DVector<f64>, posterior_weight: DMatrix<f64>) -> (DVector<f64>, DVector<f64>) {
    // ComputeSummandTerm
    // Covariates are time-independent in this model
    // And ξ_i is always 1, as discussed above
    // So we can skip this step and let xi_exp_beta_z = exp_beta_z
    let xi_exp_beta_z = &exp_beta_z;
    
    // Split these steps into functions to make profiling easier
    let (mut s0, s1, s2) = m_step_compute_s_values(data, xi_exp_beta_z);
    let sigma = m_step_compute_sigma(data, &posterior_weight, &s0, &s1, &s2);
    let new_beta = m_step_compute_new_beta(data, &posterior_weight, &s0, &s1, sigma, beta);
    s0 = m_step_compute_s0(data, beta);
    let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0);
    
    return (new_beta, new_lambda);
}
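
/// Compute the sums over observations at risk at each time point t_k:
/// S0_k = Σ_i I(t_k ≤ R*_i) ξ_i exp(β^T Z_i)
/// S1_k = Σ_i I(t_k ≤ R*_i) ξ_i exp(β^T Z_i) Z_i
/// S2_k = Σ_i I(t_k ≤ R*_i) ξ_i exp(β^T Z_i) Z_i Z_i^T
/// (here ξ_i = 1, so xi_exp_beta_z is simply exp(β^T Z_i))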
fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matrix1xX<f64>) -> (DVector<f64>, Vec<DVector<f64>>, Vec<DMatrix<f64>>) {
    // ComputeSValues
    
    // Compute s0
    let mut s0: DVector<f64> = DVector::zeros(data.num_times()); // Elements are f64
    for i in 0..data.num_obs() {
        let s0_contrib = xi_exp_beta_z[i];
        s0 += data.r_star_indicator.row(i).transpose() * s0_contrib;
    }
    
    // Precompute s1, s2 contributions for each observation
    let mut s1_contrib: Vec<DVector<f64>> = vec![DVector::zeros(data.num_covs()); data.num_obs()]; // Elements are DVector of len num-covariates
    let mut s2_contrib: Vec<DMatrix<f64>> = vec![DMatrix::zeros(data.num_covs(), data.num_covs()); data.num_obs()]; // Elements are (num-covariates, num-covariates)
    for i in 0..data.num_obs() {
        s1_contrib[i] = xi_exp_beta_z[i] * data.data_indep.column(i);
        s2_contrib[i] = xi_exp_beta_z[i] * &data.z_z_transpose[i]; // Observations are time-independent
    }
    
    let s1 = (0..data.num_times()).into_par_iter().map(|k| {
        let mut s1_k = DVector::zeros(data.num_covs());
        for i in 0..data.num_obs() {
            if data.r_star_indicator[(i, k)] == 1.0 {
                s1_k += &s1_contrib[i];
            }
        }
        s1_k
    }).collect();
    
    let s2 = (0..data.num_times()).into_par_iter().map(|k| {
        let mut s2_k = DMatrix::zeros(data.num_covs(), data.num_covs());
        for i in 0..data.num_obs() {
            if data.r_star_indicator[(i, k)] == 1.0 {
                s2_k += &s2_contrib[i];
            }
        }
        s2_k
    }).collect();
    
    return (s0, s1, s2);
}
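
/// Compute Σ = Σ_k [ Σ_i I(t_k ≤ R*_i) W_ik ] · [ (S1_k/S0_k)(S1_k/S0_k)^T - S2_k/S0_k ],
/// the matrix used in the Newton-type update for β below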
fn m_step_compute_sigma(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, s2: &Vec<DMatrix<f64>>) -> DMatrix<f64> {
    // ComputeSigma
    let mut sigma: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
    
    for k in 0..data.num_times() {
        let factor_k = (s1[k].clone() / s0[k]) * (s1[k].transpose() / s0[k]) - (s2[k].clone() / s0[k]);
        let sum_posterior_weight = data.r_star_indicator.column(k).component_mul(&posterior_weight.column(k)).sum();
        sigma += sum_posterior_weight * factor_k.clone();
    }
    
    return sigma;
}
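
/// Newton-type update for β:
/// β_new = β - Σ^-1 · Σ_k Σ_i I(t_k ≤ R*_i) W_ik (Z_i - S1_k/S0_k)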
fn m_step_compute_new_beta(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, sigma: DMatrix<f64>, beta: &DVector<f64>) -> DVector<f64> {
    // ComputeNewBeta
    assert!(sigma.clone().full_piv_lu().is_invertible(), "Sigma is not invertible");
    
    let mut sum: DVector<f64> = DVector::zeros(data.num_covs());
    for k in 0..data.num_times() {
        let quotient_k = s1[k].clone() / s0[k];
        for i in 0..data.num_obs() {
            if data.r_star_indicator[(i, k)] == 1.0 {
                sum += posterior_weight[(i, k)] * (data.data_indep.column(i) - &quotient_k);
            }
        }
    }
    
    let new_beta = beta.clone() - sigma.try_inverse().unwrap() * sum;
    return new_beta;
}
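
/// Recompute S0_k = Σ_i I(t_k ≤ R*_i) exp(β^T Z_i) for a given β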
fn m_step_compute_s0(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
    // ComputeS0
    let mut s0: DVector<f64> = DVector::zeros(data.num_times());
    for i in 0..data.num_obs() {
        //let s0_contrib = posterior_xi[i] * self.beta.dot(&data_indep.column(i)).exp();
        let s0_contrib = beta.dot(&data.data_indep.column(i)).exp();
        s0 += data.r_star_indicator.row(i).transpose() * s0_contrib;
    }
    return s0;
}
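
/// Update the baseline hazard increments: λ_k = [ Σ_i I(t_k ≤ R*_i) W_ik ] / S0_k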
fn m_step_compute_new_lambda(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>) -> DVector<f64> {
    // ComputeNewLambda
    let mut new_lambda: DVector<f64> = DVector::zeros(data.num_times());
    for k in 0..data.num_times() {
        let mut numerator_k = 0.0;
        for i in 0..data.num_obs() {
            if data.r_star_indicator[(i, k)] == 1.0 {
                numerator_k += posterior_weight[(i, k)];
            }
        }
        new_lambda[k] = numerator_k / s0[k];
    }
    
    return new_lambda;
}
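
/// Return the largest absolute change across β and the cumulative hazard Λ, and whether
/// that change is below the convergence tolerance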
fn em_check_convergence(beta: &DVector<f64>, lambda: &DVector<f64>, new_beta: &DVector<f64>, new_lambda: &DVector<f64>, tolerance: f64) -> (f64, bool) {
    let beta_diff = max_abs_difference(beta, new_beta);
    
    let old_cumulative_hazard = cumulative_hazard(lambda);
    let new_cumulative_hazard = cumulative_hazard(new_lambda);
    let lambda_diff = max_abs_difference(&old_cumulative_hazard, &new_cumulative_hazard);
    
    let max_diff = beta_diff.max(lambda_diff);
    return (max_diff, max_diff < tolerance);
}
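
/// Per-observation log-likelihood, assuming G(x) = x:
/// ℓ_i = ln(exp(-S_L,i) - exp(-S_R,i))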
fn log_likelihood_obs(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>) -> DVector<f64> {
    // Pre-compute exp(β^T * Z_ik)
    let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
    
    // Compute S_L and S_R (S_i1 and S_i2 in the paper)
    let s_left = e_step_compute_s(data, &exp_beta_z, lambda, 0);
    let s_right = e_step_compute_s(data, &exp_beta_z, lambda, 1);
    
    // Compute the log-likelihood for each observation
    // Assumes G(x) = x
    let mut result = DVector::zeros(data.num_obs());
    for i in 0..data.num_obs() {
        result[i] = ((-s_left[i]).exp() - (-s_right[i]).exp()).ln();
    }
    return result;
}
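
/// Profile log-likelihood: hold β fixed, iterate the E-step and the λ update alone until the
/// cumulative hazard converges, then return the per-observation log-likelihoods at that (β, λ)
///
/// Used to numerically difference the profile log-likelihood when computing standard errors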
fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>, mut lambda: DVector<f64>, max_iterations: u32, tolerance: f64) -> DVector<f64> {
    for _iteration in 0..max_iterations {
        // Pre-compute exp(β^T * Z_ik)
        let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
        
        // Do E-step
        let posterior_weight = do_e_step(data, &exp_beta_z, &lambda);
        
        // Do M-step (skip expensive unnecessary steps)
        let s0 = m_step_compute_s0(data, &beta);
        let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0);
        
        // Check for convergence
        let old_cumulative_hazard = cumulative_hazard(&lambda);
        let new_cumulative_hazard = cumulative_hazard(&new_lambda);
        let lambda_diff = max_abs_difference(&old_cumulative_hazard, &new_cumulative_hazard);
        lambda = new_lambda;
        
        // TODO: Incorporate into progress bar
        //println!("Profile iteration {}, estimates changed by {}", iteration + 1, lambda_diff);
        
        if lambda_diff < tolerance {
            return log_likelihood_obs(data, &beta, &lambda);
        }
    }
    
    panic!("Exceeded --max-iterations");
}

#[derive(Serialize, Deserialize)]
pub struct IntervalCensoredCoxResult {
    pub params: Vec<f64>,
    pub params_se: Vec<f64>,
    pub cumulative_hazard: Vec<f64>,
    pub cumulative_hazard_times: Vec<f64>,
    pub ll_model: f64,
    pub ll_null: f64,
}
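
/// Cumulative sum of the hazard increments: Λ_k = Σ_{j ≤ k} λ_j
/// (e.g. λ = [0.1, 0.2, 0.3] gives Λ = [0.1, 0.3, 0.6])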
fn cumulative_hazard(lambda: &DVector<f64>) -> DVector<f64> {
    let mut result = DVector::zeros(lambda.nrows());
    for (i, value) in lambda.iter().enumerate() {
        if i > 0 {
            result[i] += result[i - 1];
        }
        result[i] += value;
    }
    return result;
}
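
/// Maximum absolute elementwise difference between two vectors (infinity norm of the difference)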
fn max_abs_difference(vector_old: &DVector<f64>, vector_new: &DVector<f64>) -> f64 {
    return (vector_new - vector_old).abs().max();
}