// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

const Z_97_5: f64 = 1.959964; // 97.5th percentile of the standard normal distribution, for 95% confidence intervals

const ROW_LEFT: usize = 0;
const ROW_RIGHT: usize = 1;

use std::io;

use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{DMatrix, DVector, Matrix2xX};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};

use crate::pava::monotonic_regression_pava;
use crate::term::UnconditionalTermLike;

#[derive(Args)]
pub struct IntCoxArgs {
    /// Path to CSV input file containing the observations
    #[arg()]
    input: String,

    /// Output format
    #[arg(long, value_enum, default_value="text")]
    output: OutputFormat,

    /// Maximum number of ICM/NR iterations to attempt
    #[arg(long, default_value="1000")]
    max_iterations: u32,

    /// Terminate the ICM/NR algorithm when the absolute change in log-likelihood is less than this tolerance
    #[arg(long, default_value="0.01")]
    ll_tolerance: f64,
}

#[derive(ValueEnum, Clone)]
enum OutputFormat {
    Text,
    Json,
}

pub fn main(args: IntCoxArgs) {
    // Read data
    let (indep_names, data_times, data_indep) = read_data(&args.input);

    // Fit regression
    let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
    let result = fit_interval_censored_cox(data_times, data_indep, progress_bar, args.max_iterations, args.ll_tolerance);

    // Display output
    match args.output {
        OutputFormat::Text => {
            println!();
            println!();
            println!("LL-Model = {:.5}", result.ll_model);
            println!("LL-Null = {:.5}", result.ll_null);

            let mut summary = Table::new();
            let format = format::FormatBuilder::new()
                .separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
                .padding(2, 2)
                .build();
            summary.set_format(format);
            summary.set_titles(row!["Parameter", c->"β", c->"Std Err.", c->"exp(β)", H2c->"(95% CI)"]);

            for (i, indep_name) in indep_names.iter().enumerate() {
                summary.add_row(row![
                    indep_name,
                    r->format!("{:.5}", result.params[i]),
                    r->format!("{:.5}", result.params_se[i]),
                    r->format!("{:.5}", result.params[i].exp()),
                    r->format!("({:.5},", (result.params[i] - Z_97_5 * result.params_se[i]).exp()),
                    format!("{:.5})", (result.params[i] + Z_97_5 * result.params_se[i]).exp()),
                ]);
            }
            summary.printstd();
        }
        OutputFormat::Json => {
            println!("{}", serde_json::to_string(&result).unwrap());
        }
    }
}
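
/// Read the observations from the CSV file at `path`, or from stdin if `path` is "-".
///
/// The first two columns are the left and right endpoints of each observation's
/// censoring interval, with the literal string "inf" denoting infinity (i.e. a
/// right-censored observation); all remaining columns are covariates. Returns
/// the covariate names, the 2×n matrix of times, and the (covariates × n)
/// matrix of covariate values.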
pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
    // Read CSV into memory
    let headers: StringRecord;
    let records: Vec<StringRecord>;
    if path == "-" {
        let mut csv_reader = Reader::from_reader(io::stdin());
        headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    } else {
        let mut csv_reader = Reader::from_path(path).unwrap();
        headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    }

    // Read data into matrices

    let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros(
        //2, // Left time, right time
        records.len()
    );

    // Called "Z" in the paper and "X" in the C++ code
    let mut data_indep: DMatrix<f64> = DMatrix::zeros(
        headers.len() - 2,
        records.len()
    );

    // Parse header row
    let indep_names: Vec<String> = headers.iter().skip(2).map(String::from).collect();

    // Parse data
    for (i, row) in records.iter().enumerate() {
        for (j, item) in row.iter().enumerate() {
            let value = match item {
                "inf" => f64::INFINITY,
                _ => item.parse().expect("Malformed float")
            };

            if j < 2 {
                data_times[(j, i)] = value;
            } else {
                data_indep[(j - 2, i)] = value;
            }
        }
    }

    // TODO: Fail on left time > right time
    // TODO: Fail on left time < 0

    return (indep_names, data_times, data_indep);
}

struct IntervalCensoredCoxData {
    //data_times: DMatrix<f64>,
    data_time_indexes: Matrix2xX<usize>,
    data_indep: DMatrix<f64>,

    // Cached intermediate values
    time_points: Vec<f64>,
}

impl IntervalCensoredCoxData {
    fn num_obs(&self) -> usize {
        return self.data_indep.ncols();
    }

    fn num_covs(&self) -> usize {
        return self.data_indep.nrows();
    }

    fn num_times(&self) -> usize {
        return self.time_points.len();
    }
}

pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
    // ----------------------
    // Prepare for regression

    // Standardise values
    let indep_means = data_indep.column_mean();
    let indep_stdev = data_indep.column_variance().apply_into(|x| {
        *x = (*x * data_indep.ncols() as f64 / (data_indep.ncols() - 1) as f64).sqrt();
    });

    for j in 0..data_indep.nrows() {
        data_indep.row_mut(j).apply(|x| *x = (*x - indep_means[j]) / indep_stdev[j]);
    }

    // Get time points (t_0 = 0, t_1, ..., t_m)
    // TODO: Reimplement Turnbull intervals
    let mut time_points: Vec<f64> = data_times.iter().copied().collect();
    time_points.push(0.0); // Ensure 0 is in the list
    //time_points.push(f64::INFINITY); // Ensure infinity is in the list
    time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
    time_points.dedup();

    // Recode times as indexes
    // TODO: HashMap?
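    // Each observation's left/right time is replaced by its index into the sorted,
    // deduplicated time_points vector, so that Λ can be stored and updated as one
    // value per distinct time point. (The linear `position` search is O(m) per
    // lookup, hence the TODO above about using a HashMap instead.)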
    let data_time_indexes = Matrix2xX::from_iterator(data_times.ncols(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));

    // Initialise β, Λ
    let mut beta: DVector<f64> = DVector::zeros(data_indep.nrows());
    let mut lambda: DVector<f64> = DVector::from_iterator(time_points.len(), (0..time_points.len()).map(|i| i as f64 / time_points.len() as f64));

    let data = IntervalCensoredCoxData {
        //data_times: data_times,
        data_time_indexes: data_time_indexes,
        data_indep: data_indep,
        time_points: time_points,
    };

    // -------------------
    // Apply ICM algorithm

    let mut exp_z_beta = compute_exp_z_beta(&data, &beta);
    let mut s = compute_s(&data, &lambda, &exp_z_beta);
    let mut ll_model = log_likelihood_obs(&s).sum();

    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
    progress_bar.set_length(u64::MAX);
    progress_bar.reset();
    progress_bar.println("Running ICM/NR algorithm to fit interval-censored Cox model");

    let mut iteration = 1;
    loop {
        // Update lambda
        let lambda_new;
        (lambda_new, s, _) = update_lambda(&data, &lambda, &exp_z_beta, &s, ll_model);

        // Update beta
        let beta_new = update_beta(&data, &beta, &lambda_new, &exp_z_beta, &s);

        // Compute new log-likelihood
        exp_z_beta = compute_exp_z_beta(&data, &beta_new);
        s = compute_s(&data, &lambda_new, &exp_z_beta);
        let ll_model_new = log_likelihood_obs(&s).sum();

        let mut converged = true;
        let ll_change = ll_model_new - ll_model;
        if ll_change > ll_tolerance {
            converged = false;
        }

        lambda = lambda_new;
        beta = beta_new;
        ll_model = ll_model_new;

        // Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
        let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
        let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance.log10() * u64::MAX as f64) as u64;

        // Update progress bar
        progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
        progress_bar.set_message(format!("Iteration {} (LL = {:.4}, ΔLL = {:.4})", iteration + 1, ll_model, ll_change));

        if converged {
            progress_bar.println(format!("ICM/NR converged in {} iterations", iteration));
            break;
        }

        iteration += 1;
        if iteration > max_iterations {
            panic!("Exceeded --max-iterations");
        }
    }

    // Unstandardise betas
    let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
    for (j, beta_value) in beta.iter().enumerate() {
        beta_unstandardised[j] = beta_value / indep_stdev[j];
    }

    // -------------------------
    // Compute covariance matrix

    // Compute profile log-likelihoods
    let h = 5.0 / (data.num_obs() as f64).sqrt(); // "a constant of order n^(-1/2)"

    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Profile LL {pos}/{len}").unwrap());
    progress_bar.set_length(data.num_covs() as u64 + 2);
    progress_bar.reset();
    progress_bar.println("Profiling log-likelihood to compute covariance matrix");

    // ll_null = log-likelihood for null model
    // pll_toggle_zero = log-likelihoods for each observation at final beta
    // pll_toggle_one = log-likelihoods for each observation at toggled beta
    let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, ll_tolerance).sum();

    let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, ll_tolerance);
    progress_bar.inc(1);
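    // For each covariate j in turn, perturb beta[j] by h and re-profile the
    // log-likelihood (maximising over Λ only). The forward differences
    // (pll_toggle_one[j][i] - pll_toggle_zero[i]) / h computed below approximate
    // each observation's profile score; the sum of their outer products estimates
    // the information matrix, whose inverse gives the covariance matrix.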
    let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
        let mut pll_beta = beta.clone();
        pll_beta[j] += h;

        profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, ll_tolerance)
    })
    .progress_with(progress_bar.clone())
    .collect();
    progress_bar.finish();

    let mut pll_matrix: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
    for i in 0..data.num_obs() {
        let toggle_none_i = pll_toggle_zero[i];

        let mut ps_i: DVector<f64> = DVector::zeros(data.num_covs());
        for p in 0..data.num_covs() {
            ps_i[p] = (pll_toggle_one[p][i] - toggle_none_i) / h;
        }

        pll_matrix += ps_i.clone() * ps_i.transpose();
    }

    let vcov = pll_matrix.try_inverse().expect("Matrix not invertible");

    // Unstandardise SEs
    let beta_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
    let mut beta_se_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
    for (j, se) in beta_se.iter().enumerate() {
        beta_se_unstandardised[j] = se / indep_stdev[j];
    }

    return IntervalCensoredCoxResult {
        params: beta_unstandardised.data.as_vec().clone(),
        params_se: beta_se_unstandardised.data.as_vec().clone(),
        cumulative_hazard: lambda.data.as_vec().clone(),
        cumulative_hazard_times: data.time_points,
        ll_model: ll_model,
        ll_null: ll_null,
    };
}

/// Use with Matrix.apply_into for exponentiation
fn matrix_exp(v: &mut f64) {
    *v = v.exp();
}

fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
    return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
}

fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
    let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));

    let mut s = Matrix2xX::zeros(data.num_obs());
    s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp));
    s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp));
    return s;
}

fn log_likelihood_obs(s: &Matrix2xX<f64>) -> DVector<f64> {
    return (s.row(0) - s.row(1)).apply_into(|l| *l = l.ln()).transpose();
}
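
/// ICM update for Λ: a Newton-type step using only the diagonal of the Hessian
/// w.r.t. λ, projected back onto the set of monotone non-decreasing vectors by
/// PAVA weighted by the negated Hessian diagonal. The step is repeatedly halved
/// until the log-likelihood increases.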
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>, log_likelihood: f64) -> (DVector<f64>, Matrix2xX<f64>, f64) {
    // Compute gradient w.r.t. lambda
    let mut lambda_gradient: DVector<f64> = DVector::zeros(data.num_times());
    for i in 0..data.num_obs() {
        let constant_factor = exp_z_beta[i] / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
        lambda_gradient[data.data_time_indexes[(ROW_LEFT, i)]] -= s[(ROW_LEFT, i)] * constant_factor;
        lambda_gradient[data.data_time_indexes[(ROW_RIGHT, i)]] += s[(ROW_RIGHT, i)] * constant_factor;
    }

    // Compute diagonal elements of Hessian w.r.t lambda
    let mut lambda_hessdiag: DVector<f64> = DVector::zeros(data.num_times());
    for i in 0..data.num_obs() {
        let denominator = s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
        let aij_left = -s[(ROW_LEFT, i)] * exp_z_beta[i];
        let aij_right = s[(ROW_RIGHT, i)] * exp_z_beta[i];
        lambda_hessdiag[data.data_time_indexes[(ROW_LEFT, i)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
        lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
    }

    // Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
    let mut lambda_neghessdiag_nonsingular = -lambda_hessdiag;
    lambda_neghessdiag_nonsingular.apply(|v| *v = *v + 1e-9); // Add a small epsilon to ensure non-singular

    // To invert the diagonal matrix G, we simply have diag(1/diag(G))
    let mut lambda_invneghessdiag = lambda_neghessdiag_nonsingular.clone();
    lambda_invneghessdiag.apply(|v| *v = 1.0 / *v);

    let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient);

    // Take as large a step as possible while the log-likelihood increases
    let mut step_size_exponent: i32 = 0;
    loop {
        let step_size = 0.5_f64.powi(step_size_exponent);
        let lambda_target = lambda + step_size * &lambda_nr_factors;

        // Do projection step
        let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_neghessdiag_nonsingular.clone());
        lambda_new.apply(|l| *l = l.max(0.0));

        // Constrain Λ(0) = 0
        lambda_new[0] = 0.0;

        let s_new = compute_s(data, &lambda_new, exp_z_beta);
        let log_likelihood_new = log_likelihood_obs(&s_new).sum();

        if log_likelihood_new > log_likelihood {
            return (lambda_new, s_new, log_likelihood_new);
        }

        step_size_exponent += 1;
        if step_size_exponent > 10 {
            // This shouldn't happen unless there is a numeric problem with the gradient/Hessian
            panic!("ICM fails to increase log-likelihood");
            //return (lambda.clone(), s.clone(), log_likelihood);
        }
    }
}

fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
    // Compute gradient w.r.t. beta
    let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
    for i in 0..data.num_obs() {
        // TODO: Vectorise
        let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
        let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];

        let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
        beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
    }
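
    // The remainder of this function is the "NR" half of the ICM/NR hybrid: a
    // single Newton-Raphson step for β per outer iteration. The Hessian is
    // accumulated observation by observation as symmetric rank-1 updates, then
    // negated, inverted, and applied to the gradient.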
    // Compute Hessian w.r.t. beta
    let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
    for i in 0..data.num_obs() {
        // TODO: Vectorise
        // TODO: bli, bri same as above
        let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
        let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];

        let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
        z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
        z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
        z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);

        beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
    }

    let mut beta_neghess = -beta_hessian;
    if !beta_neghess.try_inverse_mut() {
        panic!("Hessian is not invertible");
    }

    let beta_new = beta + beta_neghess * beta_gradient;
    return beta_new;
}

fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>, mut lambda: DVector<f64>, max_iterations: u32, ll_tolerance: f64) -> DVector<f64> {
    // -------------------
    // Apply ICM algorithm

    let exp_z_beta = compute_exp_z_beta(&data, &beta);
    let mut s = compute_s(&data, &lambda, &exp_z_beta);
    let mut ll_model = log_likelihood_obs(&s).sum();

    let mut iteration = 1;
    loop {
        // Update lambda
        let (lambda_new, ll_model_new);
        (lambda_new, s, ll_model_new) = update_lambda(&data, &lambda, &exp_z_beta, &s, ll_model);

        // [Do not update beta]

        let mut converged = true;
        if ll_model_new - ll_model > ll_tolerance {
            converged = false;
        }

        lambda = lambda_new;
        ll_model = ll_model_new;

        if converged {
            return log_likelihood_obs(&s);
        }

        iteration += 1;
        if iteration > max_iterations {
            panic!("Exceeded --max-iterations");
        }
    }
}

#[derive(Serialize, Deserialize)]
pub struct IntervalCensoredCoxResult {
    pub params: Vec<f64>,
    pub params_se: Vec<f64>,
    pub cumulative_hazard: Vec<f64>,
    pub cumulative_hazard_times: Vec<f64>,
    pub ll_model: f64,
    pub ll_null: f64,
}
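
// A minimal sketch of calling the fitting routine directly from Rust rather than
// via the CLI. The file name "data.csv" is hypothetical, and ProgressBar::hidden()
// stands in for the terminal progress bar used by main():
//
//     let (indep_names, data_times, data_indep) = read_data("data.csv");
//     let result = fit_interval_censored_cox(data_times, data_indep, ProgressBar::hidden(), 1000, 0.01);
//     for (name, beta) in indep_names.iter().zip(&result.params) {
//         println!("{}: β = {:.5}", name, beta);
//     }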