2023-04-17 17:50:43 +10:00
// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// Standard normal critical value z_{0.975}, used for 95% confidence intervals
const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64

// Column indexes into the MatrixXx2 time matrices: left/right bound of the censoring interval
const COL_LEFT: usize = 0;
const COL_RIGHT: usize = 1;
2023-04-28 01:02:09 +10:00
2023-04-30 15:45:55 +10:00
use core ::mem ::MaybeUninit ;
2023-04-17 17:50:43 +10:00
use std ::io ;
use clap ::{ Args , ValueEnum } ;
2023-04-21 17:39:24 +10:00
use csv ::{ Reader , StringRecord } ;
2023-04-21 17:21:33 +10:00
use indicatif ::{ ParallelProgressIterator , ProgressBar , ProgressDrawTarget , ProgressStyle } ;
2023-05-01 00:13:32 +10:00
use nalgebra ::{ Const , DMatrix , DVector , Dyn , MatrixXx2 } ;
2023-04-17 17:50:43 +10:00
use prettytable ::{ Table , format , row } ;
use rayon ::prelude ::* ;
use serde ::{ Serialize , Deserialize } ;
2023-04-28 01:02:09 +10:00
use crate ::pava ::monotonic_regression_pava ;
2023-04-21 17:21:33 +10:00
use crate ::term ::UnconditionalTermLike ;
2023-04-17 17:50:43 +10:00
// Command-line arguments for the interval-censored Cox regression subcommand
// (note: the /// doc comments double as clap help text, so they are user-facing)
#[derive(Args)]
pub struct IntCoxArgs {
	/// Path to CSV input file containing the observations
	// "-" reads from stdin; first 2 columns are left/right times, remainder are covariates (see read_data)
	#[arg()]
	input: String,
	
	/// Output format
	#[arg(long, value_enum, default_value="text")]
	output: OutputFormat,
	
	/// Maximum number of iterations to attempt
	#[arg(long, default_value="1000")]
	max_iterations: u32,
	
	/// Terminate algorithm when the absolute change in log-likelihood is less than this tolerance
	#[arg(long, default_value="0.01")]
	ll_tolerance: f64,
}
// Supported output formats for reporting the regression result
#[derive(ValueEnum, Clone)]
enum OutputFormat {
	// Human-readable summary table on stdout
	Text,
	// Machine-readable JSON serialisation of IntervalCensoredCoxResult
	Json
}
pub fn main ( args : IntCoxArgs ) {
// Read data
2023-04-21 17:39:24 +10:00
let ( indep_names , data_times , data_indep ) = read_data ( & args . input ) ;
2023-04-17 17:50:43 +10:00
// Fit regression
2023-04-21 17:21:33 +10:00
let progress_bar = ProgressBar ::with_draw_target ( Some ( 0 ) , ProgressDrawTarget ::term_like ( Box ::new ( UnconditionalTermLike ::stderr ( ) ) ) ) ;
2023-04-28 01:02:09 +10:00
let result = fit_interval_censored_cox ( data_times , data_indep , progress_bar , args . max_iterations , args . ll_tolerance ) ;
2023-04-17 17:50:43 +10:00
// Display output
match args . output {
OutputFormat ::Text = > {
println! ( ) ;
println! ( ) ;
println! ( " LL-Model = {:.5} " , result . ll_model ) ;
println! ( " LL-Null = {:.5} " , result . ll_null ) ;
let mut summary = Table ::new ( ) ;
let format = format ::FormatBuilder ::new ( )
. separators ( & [ format ::LinePosition ::Top , format ::LinePosition ::Title , format ::LinePosition ::Bottom ] , format ::LineSeparator ::new ( '-' , '+' , '+' , '+' ) )
. padding ( 2 , 2 )
. build ( ) ;
summary . set_format ( format ) ;
summary . set_titles ( row! [ " Parameter " , c ->" β " , c ->" Std Err. " , c ->" exp(β) " , H2c ->" (95% CI) " ] ) ;
for ( i , indep_name ) in indep_names . iter ( ) . enumerate ( ) {
summary . add_row ( row! [
indep_name ,
r ->format ! ( " {:.5} " , result . params [ i ] ) ,
r ->format ! ( " {:.5} " , result . params_se [ i ] ) ,
r ->format ! ( " {:.5} " , result . params [ i ] . exp ( ) ) ,
r ->format ! ( " ({:.5}, " , ( result . params [ i ] - Z_97_5 * result . params_se [ i ] ) . exp ( ) ) ,
format! ( " {:.5} ) " , ( result . params [ i ] + Z_97_5 * result . params_se [ i ] ) . exp ( ) ) ,
] ) ;
}
summary . printstd ( ) ;
}
OutputFormat ::Json = > {
println! ( " {} " , serde_json ::to_string ( & result ) . unwrap ( ) ) ;
}
}
}
2023-05-01 00:13:32 +10:00
pub fn read_data ( path : & str ) -> ( Vec < String > , MatrixXx2 < f64 > , DMatrix < f64 > ) {
2023-04-21 17:39:24 +10:00
// Read CSV into memory
let headers : StringRecord ;
let records : Vec < StringRecord > ;
if path = = " - " {
let mut csv_reader = Reader ::from_reader ( io ::stdin ( ) ) ;
headers = csv_reader . headers ( ) . unwrap ( ) . clone ( ) ;
records = csv_reader . records ( ) . map ( | r | r . unwrap ( ) ) . collect ( ) ;
} else {
let mut csv_reader = Reader ::from_path ( path ) . unwrap ( ) ;
headers = csv_reader . headers ( ) . unwrap ( ) . clone ( ) ;
records = csv_reader . records ( ) . map ( | r | r . unwrap ( ) ) . collect ( ) ;
}
// Read data into matrices
2023-05-01 00:13:32 +10:00
// Note: data_times has one ROW per observation, whereas data_indep has one COLUMN per observation
// See comment in IntervalCensoredCoxData
2023-04-21 17:39:24 +10:00
2023-05-01 00:13:32 +10:00
let mut data_times : MatrixXx2 < MaybeUninit < f64 > > = MatrixXx2 ::uninit (
Dyn ( records . len ( ) ) ,
Const ::< 2 > // Left time, right time
2023-04-21 17:39:24 +10:00
) ;
// Called "Z" in the paper and "X" in the C++ code
2023-04-30 15:45:55 +10:00
let mut data_indep : DMatrix < MaybeUninit < f64 > > = DMatrix ::uninit (
Dyn ( headers . len ( ) - 2 ) ,
Dyn ( records . len ( ) )
2023-04-21 17:39:24 +10:00
) ;
// Parse header row
let indep_names : Vec < String > = headers . iter ( ) . skip ( 2 ) . map ( String ::from ) . collect ( ) ;
// Parse data
for ( i , row ) in records . iter ( ) . enumerate ( ) {
for ( j , item ) in row . iter ( ) . enumerate ( ) {
let value = match item {
" inf " = > f64 ::INFINITY ,
_ = > item . parse ( ) . expect ( " Malformed float " )
} ;
if j < 2 {
2023-05-01 00:13:32 +10:00
data_times [ ( i , j ) ] . write ( value ) ;
2023-04-21 17:39:24 +10:00
} else {
2023-04-30 15:45:55 +10:00
data_indep [ ( j - 2 , i ) ] . write ( value ) ;
2023-04-21 17:39:24 +10:00
}
}
}
2023-04-28 01:02:09 +10:00
// TODO: Fail on left time > right time
// TODO: Fail on left time < 0
2023-04-30 15:45:55 +10:00
// SAFETY: assume_init is OK because we initialised all values above
unsafe {
return ( indep_names , data_times . assume_init ( ) , data_indep . assume_init ( ) ) ;
}
2023-04-21 17:39:24 +10:00
}
2023-04-17 17:50:43 +10:00
// Input data and cached intermediate values for the interval-censored Cox regression
struct IntervalCensoredCoxData {
	// BEWARE! data_time_indexes has one ROW per observation, whereas data_indep has one COLUMN per observation
	// This improves the speed later by avoiding unnecessary matrix transposition
	
	//data_times: DMatrix<f64>,
	
	// For each observation, the indexes into time_points of the left and right times of its censoring interval
	data_time_indexes: MatrixXx2<usize>,
	// Covariates: one row per covariate, one column per observation
	// (called "Z" in the paper and "X" in the C++ code)
	data_indep: DMatrix<f64>,
	
	// Cached intermediate values
	// Sorted, deduplicated time points t_0 = 0, t_1, ..., t_m
	time_points: Vec<f64>,
}
impl IntervalCensoredCoxData {
	// Number of observations (data_indep has one column per observation)
	fn num_obs(&self) -> usize {
		self.data_indep.ncols()
	}
	
	// Number of covariates (data_indep has one row per covariate)
	fn num_covs(&self) -> usize {
		self.data_indep.nrows()
	}
	
	// Number of distinct time points
	fn num_times(&self) -> usize {
		self.time_points.len()
	}
}
2023-05-01 00:13:32 +10:00
pub fn fit_interval_censored_cox ( data_times : MatrixXx2 < f64 > , mut data_indep : DMatrix < f64 > , progress_bar : ProgressBar , max_iterations : u32 , ll_tolerance : f64 ) -> IntervalCensoredCoxResult {
2023-04-17 17:50:43 +10:00
// ----------------------
// Prepare for regression
// Standardise values
let indep_means = data_indep . column_mean ( ) ;
let indep_stdev = data_indep . column_variance ( ) . apply_into ( | x | { * x = ( * x * data_indep . ncols ( ) as f64 / ( data_indep . ncols ( ) - 1 ) as f64 ) . sqrt ( ) ; } ) ;
for j in 0 .. data_indep . nrows ( ) {
data_indep . row_mut ( j ) . apply ( | x | * x = ( * x - indep_means [ j ] ) / indep_stdev [ j ] ) ;
}
// Get time points (t_0 = 0, t_1, ..., t_m)
2023-04-28 01:02:09 +10:00
// TODO: Reimplement Turnbull intervals
2023-04-30 15:59:42 +10:00
let mut time_points : Vec < f64 > = Vec ::with_capacity ( data_times . len ( ) + 1 ) ;
time_points . extend ( data_times . iter ( ) ) ;
2023-04-28 01:02:09 +10:00
time_points . push ( 0.0 ) ; // Ensure 0 is in the list
//time_points.push(f64::INFINITY); // Ensure infinity is on the list
time_points . sort_by ( | a , b | a . partial_cmp ( b ) . unwrap ( ) ) ; // Cannot use .sort() as f64 does not implement Ord
time_points . dedup ( ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
// Recode times as indexes
// TODO: HashMap?
2023-05-01 00:13:32 +10:00
let data_time_indexes = MatrixXx2 ::from_iterator ( data_times . nrows ( ) , data_times . iter ( ) . map ( | t | time_points . iter ( ) . position ( | x | x = = t ) . unwrap ( ) ) ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
// Initialise β, Λ
let mut beta : DVector < f64 > = DVector ::zeros ( data_indep . nrows ( ) ) ;
let mut lambda : DVector < f64 > = DVector ::from_iterator ( time_points . len ( ) , ( 0 .. time_points . len ( ) ) . map ( | i | i as f64 / time_points . len ( ) as f64 ) ) ;
2023-04-17 17:50:43 +10:00
let data = IntervalCensoredCoxData {
2023-04-28 01:02:09 +10:00
//data_times: data_times,
data_time_indexes : data_time_indexes ,
2023-04-17 17:50:43 +10:00
data_indep : data_indep ,
time_points : time_points ,
} ;
// -------------------
2023-04-28 01:02:09 +10:00
// Apply ICM algorithm
let mut exp_z_beta = compute_exp_z_beta ( & data , & beta ) ;
let mut s = compute_s ( & data , & lambda , & exp_z_beta ) ;
let mut ll_model = log_likelihood_obs ( & s ) . sum ( ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
progress_bar . set_style ( ProgressStyle ::with_template ( " [{elapsed_precise}] {bar:40} {msg} " ) . unwrap ( ) ) ;
2023-04-17 17:50:43 +10:00
progress_bar . set_length ( u64 ::MAX ) ;
progress_bar . reset ( ) ;
2023-04-28 01:02:09 +10:00
progress_bar . println ( " Running ICM/NR algorithm to fit interval-censored Cox model " ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let mut iteration = 1 ;
2023-04-17 17:50:43 +10:00
loop {
2023-04-28 01:02:09 +10:00
// Update lambda
let lambda_new ;
( lambda_new , s , _ ) = update_lambda ( & data , & lambda , & exp_z_beta , & s , ll_model ) ;
// Update beta
let beta_new = update_beta ( & data , & beta , & lambda_new , & exp_z_beta , & s ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
// Compute new log-likelihood
exp_z_beta = compute_exp_z_beta ( & data , & beta_new ) ;
s = compute_s ( & data , & lambda_new , & exp_z_beta ) ;
let ll_model_new = log_likelihood_obs ( & s ) . sum ( ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let mut converged = true ;
let ll_change = ll_model_new - ll_model ;
if ll_change > ll_tolerance {
converged = false ;
}
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
lambda = lambda_new ;
beta = beta_new ;
ll_model = ll_model_new ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
2023-04-17 17:50:43 +10:00
let progress2 = ( iteration as f64 / max_iterations as f64 * u64 ::MAX as f64 ) as u64 ;
2023-04-28 01:02:09 +10:00
let progress3 = ( ( - ll_change . log10 ( ) ) . max ( 0.0 ) / - ll_tolerance . log10 ( ) * u64 ::MAX as f64 ) as u64 ;
2023-04-23 18:36:28 +10:00
2023-04-28 01:02:09 +10:00
// Update progress bar
progress_bar . set_position ( progress_bar . position ( ) . max ( progress3 . max ( progress2 ) ) ) ;
progress_bar . set_message ( format! ( " Iteration {} (LL = {:.4} , ΔLL = {:.4} ) " , iteration + 1 , ll_model , ll_change ) ) ;
2023-04-17 17:50:43 +10:00
if converged {
2023-04-28 01:02:09 +10:00
progress_bar . println ( format! ( " ICM/NR converged in {} iterations " , iteration ) ) ;
2023-04-17 17:50:43 +10:00
break ;
}
iteration + = 1 ;
2023-04-28 01:02:09 +10:00
if iteration > max_iterations {
2023-04-17 17:50:43 +10:00
panic! ( " Exceeded --max-iterations " ) ;
}
}
// Unstandardise betas
let mut beta_unstandardised : DVector < f64 > = DVector ::zeros ( data . num_covs ( ) ) ;
for ( j , beta_value ) in beta . iter ( ) . enumerate ( ) {
beta_unstandardised [ j ] = beta_value / indep_stdev [ j ] ;
}
// -------------------------
// Compute covariance matrix
// Compute profile log-likelihoods
let h = 5.0 / ( data . num_obs ( ) as f64 ) . sqrt ( ) ; // "a constant of order n^(-1/2)"
2023-04-28 01:02:09 +10:00
progress_bar . set_style ( ProgressStyle ::with_template ( " [{elapsed_precise}] {bar:40} Profile LL {pos}/{len} " ) . unwrap ( ) ) ;
2023-04-17 17:50:43 +10:00
progress_bar . set_length ( data . num_covs ( ) as u64 + 2 ) ;
progress_bar . reset ( ) ;
progress_bar . println ( " Profiling log-likelihood to compute covariance matrix " ) ;
// ll_null = log-likelihood for null model
// pll_toggle_zero = log-likelihoods for each observation at final beta
// pll_toggle_one = log-likelihoods for each observation at toggled beta
2023-04-28 01:02:09 +10:00
let ll_null = profile_log_likelihood_obs ( & data , DVector ::zeros ( data . num_covs ( ) ) , lambda . clone ( ) , max_iterations , ll_tolerance ) . sum ( ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let pll_toggle_zero : DVector < f64 > = profile_log_likelihood_obs ( & data , beta . clone ( ) , lambda . clone ( ) , max_iterations , ll_tolerance ) ;
2023-04-17 17:50:43 +10:00
progress_bar . inc ( 1 ) ;
let pll_toggle_one : Vec < DVector < f64 > > = ( 0 .. data . num_covs ( ) ) . into_par_iter ( ) . map ( | j | {
2023-04-28 01:02:09 +10:00
let mut pll_beta = beta . clone ( ) ;
pll_beta [ j ] + = h ;
profile_log_likelihood_obs ( & data , pll_beta , lambda . clone ( ) , max_iterations , ll_tolerance )
} )
. progress_with ( progress_bar . clone ( ) )
. collect ( ) ;
2023-04-17 17:50:43 +10:00
progress_bar . finish ( ) ;
let mut pll_matrix : DMatrix < f64 > = DMatrix ::zeros ( data . num_covs ( ) , data . num_covs ( ) ) ;
for i in 0 .. data . num_obs ( ) {
let toggle_none_i = pll_toggle_zero [ i ] ;
let mut ps_i : DVector < f64 > = DVector ::zeros ( data . num_covs ( ) ) ;
for p in 0 .. data . num_covs ( ) {
ps_i [ p ] = ( pll_toggle_one [ p ] [ i ] - toggle_none_i ) / h ;
}
pll_matrix + = ps_i . clone ( ) * ps_i . transpose ( ) ;
}
let vcov = pll_matrix . try_inverse ( ) . expect ( " Matrix not invertible " ) ;
// Unstandardise SEs
let beta_se = vcov . diagonal ( ) . apply_into ( | x | { * x = x . sqrt ( ) ; } ) ;
let mut beta_se_unstandardised : DVector < f64 > = DVector ::zeros ( data . num_covs ( ) ) ;
for ( j , se ) in beta_se . iter ( ) . enumerate ( ) {
beta_se_unstandardised [ j ] = se / indep_stdev [ j ] ;
}
return IntervalCensoredCoxResult {
params : beta_unstandardised . data . as_vec ( ) . clone ( ) ,
params_se : beta_se_unstandardised . data . as_vec ( ) . clone ( ) ,
2023-04-28 01:02:09 +10:00
cumulative_hazard : lambda . data . as_vec ( ) . clone ( ) ,
2023-04-18 16:18:19 +10:00
cumulative_hazard_times : data . time_points ,
2023-04-17 17:50:43 +10:00
ll_model : ll_model ,
ll_null : ll_null ,
} ;
}
2023-04-30 15:45:55 +10:00
// Apply exp() elementwise to a matrix expression, consuming and returning it
macro_rules! matrix_exp {
	($matrix:expr) => {
		{
			let mut result = $matrix;
			// Sequential elementwise exponential
			//result.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
			result.apply(|x| *x = x.exp());
			result
		}
	}
}
2023-04-28 01:02:09 +10:00
fn compute_exp_z_beta ( data : & IntervalCensoredCoxData , beta : & DVector < f64 > ) -> DVector < f64 > {
2023-04-30 15:45:55 +10:00
return matrix_exp! ( data . data_indep . tr_mul ( beta ) ) ;
2023-04-28 01:02:09 +10:00
}
2023-05-01 00:13:32 +10:00
fn compute_s ( data : & IntervalCensoredCoxData , lambda : & DVector < f64 > , exp_z_beta : & DVector < f64 > ) -> MatrixXx2 < f64 > {
let cumulative_hazard = MatrixXx2 ::from_iterator ( data . num_obs ( ) , data . data_time_indexes . iter ( ) . map ( | i | lambda [ * i ] ) ) ; // Cannot use apply() as different data types
2023-04-28 01:02:09 +10:00
2023-05-01 00:13:32 +10:00
let mut s = MatrixXx2 ::zeros ( data . num_obs ( ) ) ;
s . set_column ( COL_LEFT , & matrix_exp! ( ( - exp_z_beta ) . component_mul ( & cumulative_hazard . column ( 0 ) ) ) ) ;
s . set_column ( COL_RIGHT , & matrix_exp! ( ( - exp_z_beta ) . component_mul ( & cumulative_hazard . column ( 1 ) ) ) ) ;
2023-04-17 17:50:43 +10:00
return s ;
}
2023-05-01 00:13:32 +10:00
fn log_likelihood_obs ( s : & MatrixXx2 < f64 > ) -> DVector < f64 > {
return ( s . column ( COL_LEFT ) - s . column ( COL_RIGHT ) ) . apply_into ( | l | * l = l . ln ( ) ) ;
2023-04-17 17:50:43 +10:00
}
2023-05-01 00:13:32 +10:00
fn update_lambda ( data : & IntervalCensoredCoxData , lambda : & DVector < f64 > , exp_z_beta : & DVector < f64 > , s : & MatrixXx2 < f64 > , log_likelihood : f64 ) -> ( DVector < f64 > , MatrixXx2 < f64 > , f64 ) {
2023-04-28 01:02:09 +10:00
// Compute gradient w.r.t. lambda
let mut lambda_gradient : DVector < f64 > = DVector ::zeros ( data . num_times ( ) ) ;
2023-04-17 17:50:43 +10:00
for i in 0 .. data . num_obs ( ) {
2023-05-01 00:13:32 +10:00
let constant_factor = exp_z_beta [ i ] / ( s [ ( i , COL_LEFT ) ] - s [ ( i , COL_RIGHT ) ] ) ;
lambda_gradient [ data . data_time_indexes [ ( i , COL_LEFT ) ] ] - = s [ ( i , COL_LEFT ) ] * constant_factor ;
lambda_gradient [ data . data_time_indexes [ ( i , COL_RIGHT ) ] ] + = s [ ( i , COL_RIGHT ) ] * constant_factor ;
2023-04-17 17:50:43 +10:00
}
2023-04-28 01:02:09 +10:00
// Compute diagonal elements of Hessian w.r.t lambda
let mut lambda_hessdiag : DVector < f64 > = DVector ::zeros ( data . num_times ( ) ) ;
2023-04-17 17:50:43 +10:00
for i in 0 .. data . num_obs ( ) {
2023-05-01 00:13:32 +10:00
// TODO: Vectorise?
let denominator = s [ ( i , COL_LEFT ) ] - s [ ( i , COL_RIGHT ) ] ;
let aij_left = - s [ ( i , COL_LEFT ) ] * exp_z_beta [ i ] ;
let aij_right = s [ ( i , COL_RIGHT ) ] * exp_z_beta [ i ] ;
2023-04-28 01:02:09 +10:00
2023-05-01 00:13:32 +10:00
lambda_hessdiag [ data . data_time_indexes [ ( i , COL_LEFT ) ] ] + = ( - aij_left * exp_z_beta [ i ] ) / denominator - ( aij_left / denominator ) . powi ( 2 ) ;
lambda_hessdiag [ data . data_time_indexes [ ( i , COL_RIGHT ) ] ] + = ( - aij_right * exp_z_beta [ i ] ) / denominator - ( aij_right / denominator ) . powi ( 2 ) ;
2023-04-17 17:50:43 +10:00
}
2023-04-29 17:39:25 +10:00
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
let mut lambda_neghessdiag_nonsingular = - lambda_hessdiag ;
lambda_neghessdiag_nonsingular . apply ( | v | * v = * v + 1e-9 ) ; // Add a small epsilon to ensure non-singular
2023-04-28 01:02:09 +10:00
// To invert the diagonal matrix G, we simply have diag(1/diag(G))
2023-04-29 17:39:25 +10:00
let mut lambda_invneghessdiag = lambda_neghessdiag_nonsingular . clone ( ) ;
lambda_invneghessdiag . apply ( | v | * v = 1.0 / * v ) ;
2023-04-28 01:02:09 +10:00
let lambda_nr_factors = lambda_invneghessdiag . component_mul ( & lambda_gradient ) ;
// Take as large a step as possible while the log-likelihood increases
let mut step_size_exponent : i32 = 0 ;
loop {
let step_size = 0.5_ f64 . powi ( step_size_exponent ) ;
let lambda_target = lambda + step_size * & lambda_nr_factors ;
// Do projection step
2023-04-29 17:39:25 +10:00
let mut lambda_new = monotonic_regression_pava ( lambda_target , lambda_neghessdiag_nonsingular . clone ( ) ) ;
2023-04-28 01:02:09 +10:00
lambda_new . apply ( | l | * l = l . max ( 0.0 ) ) ;
// Constrain Λ(0) = 0
lambda_new [ 0 ] = 0.0 ;
let s_new = compute_s ( data , & lambda_new , exp_z_beta ) ;
let log_likelihood_new = log_likelihood_obs ( & s_new ) . sum ( ) ;
if log_likelihood_new > log_likelihood {
return ( lambda_new , s_new , log_likelihood_new ) ;
2023-04-17 17:50:43 +10:00
}
2023-04-28 01:02:09 +10:00
step_size_exponent + = 1 ;
if step_size_exponent > 10 {
// This shouldn't happen unless there is a numeric problem with the gradient/Hessian
panic! ( " ICM fails to increase log-likelihood " ) ;
//return (lambda.clone(), s.clone(), log_likelihood);
2023-04-17 17:50:43 +10:00
}
}
}
2023-05-01 00:13:32 +10:00
fn update_beta ( data : & IntervalCensoredCoxData , beta : & DVector < f64 > , lambda : & DVector < f64 > , exp_z_beta : & DVector < f64 > , s : & MatrixXx2 < f64 > ) -> DVector < f64 > {
2023-04-29 18:29:33 +10:00
// Compute gradient and Hessian w.r.t. beta
2023-04-28 01:02:09 +10:00
let mut beta_gradient : DVector < f64 > = DVector ::zeros ( data . num_covs ( ) ) ;
2023-04-29 18:29:33 +10:00
let mut beta_hessian : DMatrix < f64 > = DMatrix ::zeros ( data . num_covs ( ) , data . num_covs ( ) ) ;
2023-04-28 01:02:09 +10:00
for i in 0 .. data . num_obs ( ) {
2023-04-29 18:29:33 +10:00
// TODO: Can this be vectorised? Seems unlikely however
2023-05-01 00:13:32 +10:00
let bli = s [ ( i , COL_LEFT ) ] * exp_z_beta [ i ] * lambda [ data . data_time_indexes [ ( i , COL_LEFT ) ] ] ;
let bri = s [ ( i , COL_RIGHT ) ] * exp_z_beta [ i ] * lambda [ data . data_time_indexes [ ( i , COL_RIGHT ) ] ] ;
2023-04-29 18:29:33 +10:00
// Gradient
2023-05-01 00:13:32 +10:00
let z_factor = ( bri - bli ) / ( s [ ( i , COL_LEFT ) ] - s [ ( i , COL_RIGHT ) ] ) ;
2023-04-29 17:39:25 +10:00
beta_gradient . axpy ( z_factor , & data . data_indep . column ( i ) , 1.0 ) ; // beta_gradient += z_factor * data.data_indep.column(i);
2023-04-28 01:02:09 +10:00
2023-04-29 18:29:33 +10:00
// Hessian
2023-05-01 00:13:32 +10:00
let mut z_factor = exp_z_beta [ i ] * lambda [ data . data_time_indexes [ ( i , COL_RIGHT ) ] ] * ( s [ ( i , COL_RIGHT ) ] - bri ) ;
z_factor - = exp_z_beta [ i ] * lambda [ data . data_time_indexes [ ( i , COL_LEFT ) ] ] * ( s [ ( i , COL_LEFT ) ] - bli ) ;
z_factor / = s [ ( i , COL_LEFT ) ] - s [ ( i , COL_RIGHT ) ] ;
2023-04-28 01:02:09 +10:00
2023-05-01 00:13:32 +10:00
z_factor - = ( ( bri - bli ) / ( s [ ( i , COL_LEFT ) ] - s [ ( i , COL_RIGHT ) ] ) ) . powi ( 2 ) ;
2023-04-28 01:02:09 +10:00
2023-04-29 17:39:25 +10:00
beta_hessian . syger ( z_factor , & data . data_indep . column ( i ) , & data . data_indep . column ( i ) , 1.0 ) ; // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
2023-04-17 17:50:43 +10:00
}
2023-04-28 01:02:09 +10:00
let mut beta_neghess = - beta_hessian ;
if ! beta_neghess . try_inverse_mut ( ) {
panic! ( " Hessian is not invertible " ) ;
}
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let beta_new = beta + beta_neghess * beta_gradient ;
return beta_new ;
2023-04-17 17:50:43 +10:00
}
2023-04-28 01:02:09 +10:00
fn profile_log_likelihood_obs ( data : & IntervalCensoredCoxData , beta : DVector < f64 > , mut lambda : DVector < f64 > , max_iterations : u32 , ll_tolerance : f64 ) -> DVector < f64 > {
// -------------------
// Apply ICM algorithm
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let exp_z_beta = compute_exp_z_beta ( & data , & beta ) ;
let mut s = compute_s ( & data , & lambda , & exp_z_beta ) ;
let mut ll_model = log_likelihood_obs ( & s ) . sum ( ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let mut iteration = 1 ;
loop {
// Update lambda
let ( lambda_new , ll_model_new ) ;
( lambda_new , s , ll_model_new ) = update_lambda ( & data , & lambda , & exp_z_beta , & s , ll_model ) ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
// [Do not update beta]
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
let mut converged = true ;
if ll_model_new - ll_model > ll_tolerance {
converged = false ;
}
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
lambda = lambda_new ;
ll_model = ll_model_new ;
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
if converged {
return log_likelihood_obs ( & s ) ;
}
2023-04-17 17:50:43 +10:00
2023-04-28 01:02:09 +10:00
iteration + = 1 ;
if iteration > max_iterations {
panic! ( " Exceeded --max-iterations " ) ;
2023-04-17 17:50:43 +10:00
}
}
}
// Result of fitting an interval-censored Cox model
// (serialisable for the JSON output format)
#[derive(Serialize, Deserialize)]
pub struct IntervalCensoredCoxResult {
	// Fitted coefficients β, on the original (unstandardised) covariate scale
	pub params: Vec<f64>,
	// Standard errors of the coefficients, on the original covariate scale
	pub params_se: Vec<f64>,
	// Estimated baseline cumulative hazard Λ at each of cumulative_hazard_times
	pub cumulative_hazard: Vec<f64>,
	pub cumulative_hazard_times: Vec<f64>,
	// Log-likelihood of the fitted model
	pub ll_model: f64,
	// Log-likelihood of the null model (β = 0)
	pub ll_null: f64,
}