// hpstat/src/intcox.rs

// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use std::fs;
use std::io;

use clap::{Args, ValueEnum};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle};
use nalgebra::{DMatrix, DVector, Matrix1xX};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};

// 97.5th percentile of the standard normal distribution (Φ⁻¹(0.975) ≈ 1.959964), used for 95% confidence intervals
const Z_97_5: f64 = 1.959964;
#[derive(Args)]
pub struct IntCoxArgs {
/// Path to CSV input file containing the observations
#[arg()]
input: String,
/// Output format
#[arg(long, value_enum, default_value="text")]
output: OutputFormat,
/// Maximum number of E-M iterations to attempt
#[arg(long, default_value="100")]
max_iterations: u32,
/// Terminate E-M algorithm when the maximum absolute change in all parameters is less than this tolerance
#[arg(long, default_value="0.001")]
tolerance: f64,
/// Estimate baseline hazard function using Turnbull innermost intervals
#[arg(long)]
reduced: bool,
}
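// Expected input layout (per the parsing below): a header row, then one row per
// observation; the first two columns are the left and right interval times ("inf" for a
// right-censored observation) and the remaining columns are covariates.
// Hypothetical example (column names are illustrative only):
//
//   left,right,age,treatment
//   0,3,54,1
//   2,inf,61,0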
#[derive(ValueEnum, Clone)]
enum OutputFormat {
Text,
Json
}
pub fn main(args: IntCoxArgs) {
let lines: Vec<String>;
if args.input == "-" {
lines = io::stdin().lines().map(|l| l.unwrap()).collect();
} else {
let contents = fs::read_to_string(args.input).unwrap();
lines = contents.trim_end().split("\n").map(|s| s.to_string()).collect();
}
// Read data into matrices
let mut data_times: DMatrix<f64> = DMatrix::zeros(
2, // Left time, right time
lines.len() - 1 // Minus 1 row for header row
);
// Called "Z" in the paper and "X" in the C++ code
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
lines[0].split(",").count() - 2,
lines.len() - 1 // Minus 1 row for header row
);
// Read header row
let indep_names: Vec<&str> = lines[0].split(",").skip(2).collect();
// Read data
// FIXME: Parse CSV more robustly
for (i, row) in lines.iter().skip(1).enumerate() {
for (j, item) in row.split(",").enumerate() {
let value = match item {
"inf" => f64::INFINITY,
_ => item.parse().expect("Malformed float")
};
if j < 2 {
data_times[(j, i)] = value;
} else {
data_indep[(j - 2, i)] = value;
}
}
}
// Fit regression
let progress_bar = match args.output {
OutputFormat::Text => ProgressBar::new(0),
OutputFormat::Json => ProgressBar::hidden(),
};
let result = fit_interval_censored_cox(data_times, data_indep, args.max_iterations, args.tolerance, args.reduced, progress_bar);
// Display output
match args.output {
OutputFormat::Text => {
println!();
println!();
println!("LL-Model = {:.5}", result.ll_model);
println!("LL-Null = {:.5}", result.ll_null);
let mut summary = Table::new();
let format = format::FormatBuilder::new()
.separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
.padding(2, 2)
.build();
summary.set_format(format);
summary.set_titles(row!["Parameter", c->"β", c->"Std Err.", c->"exp(β)", H2c->"(95% CI)"]);
for (i, indep_name) in indep_names.iter().enumerate() {
summary.add_row(row![
indep_name,
r->format!("{:.5}", result.params[i]),
r->format!("{:.5}", result.params_se[i]),
r->format!("{:.5}", result.params[i].exp()),
r->format!("({:.5},", (result.params[i] - Z_97_5 * result.params_se[i]).exp()),
format!("{:.5})", (result.params[i] + Z_97_5 * result.params_se[i]).exp()),
]);
}
summary.printstd();
}
OutputFormat::Json => {
println!("{}", serde_json::to_string(&result).unwrap());
}
}
}
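// Illustrative invocation (the "intcox" subcommand name is an assumption based on this
// file's name; the flags follow from the clap definitions above):
//
//   hpstat intcox data.csv --max-iterations 200 --tolerance 0.0001 --output json
//
// Pass "-" as the input path to read the CSV from standard input.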
struct IntervalCensoredCoxData {
data_times: DMatrix<f64>,
data_indep: DMatrix<f64>,
// Cached intermediate values
time_points: Vec<f64>,
r_star_indicator: DMatrix<f64>,
z_z_transpose: Vec<DMatrix<f64>>,
}
impl IntervalCensoredCoxData {
fn num_obs(&self) -> usize {
return self.data_indep.ncols();
}
fn num_covs(&self) -> usize {
return self.data_indep.nrows();
}
fn num_times(&self) -> usize {
return self.time_points.len();
}
}
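// Note on layout: observations are stored column-wise throughout (one column per
// observation, one row per covariate or interval bound), so per-observation access
// below is by column view.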
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
// ----------------------
// Prepare for regression
// Standardise covariates to zero mean and unit sample standard deviation
// (β and standard errors are converted back to the original scale before being returned)
let indep_means = data_indep.column_mean();
// column_variance computes the population variance; rescale by n/(n-1) before taking the square root to obtain the sample standard deviation
let indep_stdev = data_indep.column_variance().apply_into(|x| { *x = (*x * data_indep.ncols() as f64 / (data_indep.ncols() - 1) as f64).sqrt(); });
for j in 0..data_indep.nrows() {
data_indep.row_mut(j).apply(|x| *x = (*x - indep_means[j]) / indep_stdev[j]);
}
// Get time points (t_0 = 0, t_1, ..., t_m)
let mut time_points: Vec<f64>;
if reduced {
// Turnbull intervals
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true)));
all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false)));
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
time_points = Vec::new();
for i in 1..all_time_points.len() {
// A left endpoint immediately followed by a right endpoint delimits a Turnbull innermost interval
if all_time_points[i - 1].1 && !all_time_points[i].1 {
time_points.push(all_time_points[i - 1].0);
time_points.push(all_time_points[i].0);
}
}
time_points.push(0.0); // Ensure 0 is in the list
time_points.retain(|t| t.is_finite()); // Remove infinity
time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
time_points.dedup();
} else {
// All observed intervals
time_points = data_times.iter().copied().collect();
time_points.push(0.0); // Ensure 0 is in the list
time_points.retain(|t| t.is_finite()); // Remove infinity
time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
time_points.dedup();
}
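// Worked example of the reduced case: for observations (1, 3] and (2, 4], the sorted
// endpoints are 1(L), 2(L), 3(R), 4(R). The only left endpoint immediately followed by
// a right endpoint is 2 followed by 3, so the Turnbull innermost interval is (2, 3] and
// the retained time points are {0, 2, 3}.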
// Initialise β, λ
let mut beta = DVector::zeros(data_indep.nrows());
let mut lambda = DVector::repeat(time_points.len(), 1.0 / (time_points.len() - 1) as f64);
// Compute I(t_k <= R*_i)
// Where R*_i is R_i if R_i ≠ ∞, otherwise it is L_i
let mut r_star_indicator = DMatrix::zeros(data_indep.ncols(), time_points.len());
for (i, observation) in data_times.column_iter().enumerate() {
let time_right_star = if observation[1].is_finite() { observation[1] } else { observation[0] };
for (k, time) in time_points.iter().enumerate() {
if *time <= time_right_star {
// t_k <= R*_i
r_star_indicator[(i, k)] = 1.0;
} else {
r_star_indicator[(i, k)] = 0.0;
}
}
}
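// For example, an observation with L_i = 2 and R_i = ∞ has R*_i = 2, so its indicator
// row is 1 for each t_k <= 2 and 0 thereafter.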
// Pre-compute Z * Z^T
// Indexed by observation -> Matrix (num-covariates, num-covariates)
let mut z_z_transpose: Vec<DMatrix<f64>> = Vec::new();
for i in 0..data_indep.ncols() {
let covariates = data_indep.column(i);
z_z_transpose.push(covariates * covariates.transpose());
}
let data = IntervalCensoredCoxData {
data_times: data_times,
data_indep: data_indep,
time_points: time_points,
r_star_indicator: r_star_indicator,
z_z_transpose: z_z_transpose,
};
// -------------------
// Apply E-M algorithm
progress_bar.set_length(u64::MAX);
progress_bar.reset();
progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
progress_bar.println("Running E-M algorithm to fit interval-censored Cox model");
let mut iteration: u32 = 0;
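// Each E-M iteration below performs:
//   E-step: posterior weights W_ik = λ_k exp(β'Z_i) / (1 - exp(S_L_i - S_R_i))
//           for L_i < t_k <= R_i (see do_e_step)
//   M-step: a single Newton step for β using the weighted score and information Σ
//           (see m_step_compute_new_beta), then, at the updated β,
//           λ_k = Σ_i I(t_k <= R*_i) W_ik / Σ_i I(t_k <= R*_i) exp(β'Z_i)
//           (see m_step_compute_new_lambda)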
loop {
// Pre-compute exp(β^T * Z_ik)
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
// Do E-step
let posterior_weight = do_e_step(&data, &exp_beta_z, &lambda);
// Do M-step
let (new_beta, new_lambda) = do_m_step(&data, &exp_beta_z, &beta, posterior_weight);
// Check for convergence
let (coef_change, converged) = em_check_convergence(&beta, &lambda, &new_beta, &new_lambda, tolerance);
beta = new_beta;
lambda = new_lambda;
// Update progress bar
// Estimate progress according to either the order of magnitude of the coef_change relative to tolerance, or iteration/max_iterations
let progress1 = ((-coef_change.log10()).max(0.0) / -tolerance.log10() * u64::MAX as f64) as u64;
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
progress_bar.set_position(progress_bar.position().max(progress1).max(progress2));
progress_bar.set_message(format!("Iter {} (delta = {:.6})", iteration + 1, coef_change));
if converged {
progress_bar.finish();
break;
}
iteration += 1;
if iteration >= max_iterations {
panic!("Exceeded --max-iterations");
}
}
// Compute log-likelihood
let ll_model = log_likelihood_obs(&data, &beta, &lambda).sum();
// Unstandardise betas
let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
for (j, beta_value) in beta.iter().enumerate() {
beta_unstandardised[j] = beta_value / indep_stdev[j];
}
// -------------------------
// Compute covariance matrix
// Compute profile log-likelihoods
let h = 5.0 / (data.num_obs() as f64).sqrt(); // "a constant of order n^(-1/2)"
progress_bar.set_length(data.num_covs() as u64 + 2);
progress_bar.reset();
progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Profile LL {pos}/{len}").unwrap());
progress_bar.println("Profiling log-likelihood to compute covariance matrix");
// ll_null = log-likelihood for null model
// pll_toggle_zero = log-likelihoods for each observation at final beta
// pll_toggle_one = log-likelihoods for each observation at toggled beta
let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, tolerance).sum();
let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, tolerance);
progress_bar.inc(1);
let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
let mut pll_beta = beta.clone();
pll_beta[j] += h;
profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, tolerance)
})
.progress_with(progress_bar.clone())
.collect();
progress_bar.finish();
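// The covariance matrix is estimated from numerical profile-likelihood scores:
// for each observation i, p_i[j] ≈ (pll_toggle_one[j][i] - pll_toggle_zero[i]) / h is
// the profile score for covariate j, and vcov = (Σ_i p_i p_i')⁻¹.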
let mut pll_matrix: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() {
let toggle_none_i = pll_toggle_zero[i];
let mut ps_i: DVector<f64> = DVector::zeros(data.num_covs());
for p in 0..data.num_covs() {
ps_i[p] = (pll_toggle_one[p][i] - toggle_none_i) / h;
}
pll_matrix += ps_i.clone() * ps_i.transpose();
}
let vcov = pll_matrix.try_inverse().expect("Matrix not invertible");
// Unstandardise SEs
let beta_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
let mut beta_se_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
for (j, se) in beta_se.iter().enumerate() {
beta_se_unstandardised[j] = se / indep_stdev[j];
}
return IntervalCensoredCoxResult {
params: beta_unstandardised.data.as_vec().clone(),
params_se: beta_se_unstandardised.data.as_vec().clone(),
cumulative_hazard: cumulative_hazard(&lambda).data.as_vec().clone(),
cumulative_hazard_times: data.time_points,
ll_model: ll_model,
ll_null: ll_null,
};
}
fn do_e_step(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, lambda: &DVector<f64>) -> DMatrix<f64> {
// Compute S_L and S_R (S_i1 and S_i2 in the paper)
let s_left = e_step_compute_s(data, &exp_beta_z, lambda, 0);
let s_right = e_step_compute_s(data, &exp_beta_z, lambda, 1);
// In the paper, consideration is given to G(x)
// But in a proportional hazards model, G(x) = x
// So we omit the details
// As a consequence, the posterior ξ_i are always 1
// Compute posterior weights (W_ik, "posterior mean" in C++)
let mut posterior_weight: DMatrix<f64> = DMatrix::zeros(data.num_obs(), data.num_times());
for (i, observation) in data.data_times.column_iter().enumerate() {
let time_left = observation[0];
let time_right = observation[1];
for (k, time) in data.time_points.iter().enumerate() {
if *time <= time_left {
// t_k <= L_i
posterior_weight[(i, k)] = 0.0;
} else if *time <= time_right && time_right.is_finite() {
// L_i < t_k <= R_i, with R_i < ∞
// Assumes r = 0
posterior_weight[(i, k)] = lambda[k] * exp_beta_z[i] / (1.0 - (s_left[i] - s_right[i]).exp());
} else {
// None of the above circumstances
// C++ says the weight is unused in this case
// Set this to a non-NaN value so we can still do elementwise vector multiplication for masking
posterior_weight[(i, k)] = 0.0;
}
}
}
return posterior_weight;
}
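// e_step_compute_s computes S_i = Σ_{t_k <= T} λ_k exp(β'Z_i), the cumulative hazard at
// the requested interval bound (time_index 0 = L_i, 1 = R_i) scaled by the relative
// risk; an infinite bound yields S_i = ∞.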
fn e_step_compute_s(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, lambda: &DVector<f64>, time_index: usize) -> DVector<f64> {
let mut s: DVector<f64> = DVector::zeros(data.num_obs());
for (i, observation) in data.data_times.column_iter().enumerate() {
let time_cutoff = observation[time_index];
if time_cutoff.is_infinite() {
s[i] = f64::INFINITY;
} else {
for (k, time) in data.time_points.iter().enumerate() {
if *time <= time_cutoff {
// time is t_k <= L_i, or t_k <= R_i, as applicable
s[i] += lambda[k] * exp_beta_z[i]; // exp(β'Z_i) does not depend on k, as all covariates are time-independent
} else {
break;
}
}
}
}
return s;
}
fn do_m_step(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, beta: &DVector<f64>, posterior_weight: DMatrix<f64>) -> (DVector<f64>, DVector<f64>) {
// ComputeSummandTerm
// Covariates are time-independent in this model
// And ξ_i is always 1, as discussed above
// So we can skip this step and let xi_exp_beta_z = exp_beta_z
let xi_exp_beta_z = &exp_beta_z;
// Split these steps into functions to make profiling easier
let (mut s0, s1, s2) = m_step_compute_s_values(data, xi_exp_beta_z);
let sigma = m_step_compute_sigma(data, &posterior_weight, &s0, &s1, &s2);
let new_beta = m_step_compute_new_beta(data, &posterior_weight, &s0, &s1, sigma, beta);
s0 = m_step_compute_s0(data, &new_beta); // Recompute S0 at the updated β, so the λ update below reflects the new coefficients (recomputing at the old β would just reproduce the S0 above)
let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0);
return (new_beta, new_lambda);
}
fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matrix1xX<f64>) -> (DVector<f64>, Vec<DVector<f64>>, Vec<DMatrix<f64>>) {
// ComputeSValues
// Compute s0
let mut s0: DVector<f64> = DVector::zeros(data.num_times()); // Elements are f64
for i in 0..data.num_obs() {
let s0_contrib = xi_exp_beta_z[i];
s0 += data.r_star_indicator.row(i).transpose() * s0_contrib;
}
// Precompute s1, s2 contributions for each observation
let mut s1_contrib: Vec<DVector<f64>> = vec![DVector::zeros(data.num_covs()); data.num_obs()]; // Elements are DVector of len num-covariates
let mut s2_contrib: Vec<DMatrix<f64>> = vec![DMatrix::zeros(data.num_covs(), data.num_covs()); data.num_obs()]; // Elements are (num-covariates, num-covariates)
for i in 0..data.num_obs() {
s1_contrib[i] = xi_exp_beta_z[i] * data.data_indep.column(i);
s2_contrib[i] = xi_exp_beta_z[i] * &data.z_z_transpose[i]; // Observations are time-independent
}
let s1 = (0..data.num_times()).into_par_iter().map(|k| {
let mut s1_k = DVector::zeros(data.num_covs());
for i in 0..data.num_obs() {
if data.r_star_indicator[(i, k)] == 1.0 {
s1_k += &s1_contrib[i];
}
}
s1_k
}).collect();
let s2 = (0..data.num_times()).into_par_iter().map(|k| {
let mut s2_k = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() {
if data.r_star_indicator[(i, k)] == 1.0 {
s2_k += &s2_contrib[i];
}
}
s2_k
}).collect();
return (s0, s1, s2);
}
fn m_step_compute_sigma(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, s2: &Vec<DMatrix<f64>>) -> DMatrix<f64> {
// ComputeSigma
let mut sigma: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for k in 0..data.num_times() {
let factor_k = (s1[k].clone() / s0[k]) * (s1[k].transpose() / s0[k]) - (s2[k].clone() / s0[k]);
let sum_posterior_weight = data.r_star_indicator.column(k).component_mul(&posterior_weight.column(k)).sum();
sigma += sum_posterior_weight * factor_k.clone();
}
return sigma;
}
fn m_step_compute_new_beta(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, sigma: DMatrix<f64>, beta: &DVector<f64>) -> DVector<f64> {
// ComputeNewBeta
assert!(sigma.clone().full_piv_lu().is_invertible(), "Sigma is not invertible");
let mut sum: DVector<f64> = DVector::zeros(data.num_covs());
for k in 0..data.num_times() {
let quotient_k = s1[k].clone() / s0[k];
for i in 0..data.num_obs() {
if data.r_star_indicator[(i, k)] == 1.0 {
sum += posterior_weight[(i, k)] * (data.data_indep.column(i) - &quotient_k);
}
}
}
let new_beta = beta.clone() - sigma.try_inverse().unwrap() * sum;
return new_beta;
}
fn m_step_compute_s0(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
// ComputeS0
let mut s0: DVector<f64> = DVector::zeros(data.num_times());
for i in 0..data.num_obs() {
// let s0_contrib = posterior_xi[i] * self.beta.dot(&data_indep.column(i)).exp();
let s0_contrib = beta.dot(&data.data_indep.column(i)).exp();
s0 += data.r_star_indicator.row(i).transpose() * s0_contrib;
}
return s0;
}
fn m_step_compute_new_lambda(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>) -> DVector<f64> {
// ComputeNewLambda
let mut new_lambda: DVector<f64> = DVector::zeros(data.num_times());
for k in 0..data.num_times() {
let mut numerator_k = 0.0;
for i in 0..data.num_obs() {
if data.r_star_indicator[(i, k)] == 1.0 {
numerator_k += posterior_weight[(i, k)];
}
}
new_lambda[k] = numerator_k / s0[k];
}
return new_lambda;
}
fn em_check_convergence(beta: &DVector<f64>, lambda: &DVector<f64>, new_beta: &DVector<f64>, new_lambda: &DVector<f64>, tolerance: f64) -> (f64, bool) {
let beta_diff = max_abs_difference(beta, new_beta);
let old_cumulative_hazard = cumulative_hazard(lambda);
let new_cumulative_hazard = cumulative_hazard(new_lambda);
let lambda_diff = max_abs_difference(&old_cumulative_hazard, &new_cumulative_hazard);
let max_diff = beta_diff.max(lambda_diff);
return (max_diff, max_diff < tolerance);
}
fn log_likelihood_obs(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>) -> DVector<f64> {
// Pre-compute exp(β^T * Z_ik)
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
// Compute S_L and S_R (S_i1 and S_i2 in the paper)
let s_left = e_step_compute_s(data, &exp_beta_z, lambda, 0);
let s_right = e_step_compute_s(data, &exp_beta_z, lambda, 1);
// Compute the log-likelihood by summing log-likelihood for each observation
// Assumes G(x) = x
let mut result = DVector::zeros(data.num_obs());
for i in 0..data.num_obs() {
result[i] = ((-s_left[i]).exp() - (-s_right[i]).exp()).ln();
}
return result;
}
fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>, mut lambda: DVector<f64>, max_iterations: u32, tolerance: f64) -> DVector<f64> {
for _iteration in 0..max_iterations {
// Pre-compute exp(β^T * Z_ik)
let exp_beta_z: Matrix1xX<f64> = (beta.transpose() * &data.data_indep).apply_into(|x| { *x = x.exp(); });
// Do E-step
let posterior_weight = do_e_step(data, &exp_beta_z, &lambda);
// Do M-step (skip expensive unnecessary steps)
let s0 = m_step_compute_s0(data, &beta);
let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0);
// Check for convergence
let old_cumulative_hazard = cumulative_hazard(&lambda);
let new_cumulative_hazard = cumulative_hazard(&new_lambda);
let lambda_diff = max_abs_difference(&old_cumulative_hazard, &new_cumulative_hazard);
lambda = new_lambda;
// TODO: Incorporate into progress bar
//println!("Profile iteration {}, estimates changed by {}", iteration + 1, lambda_diff);
if lambda_diff < tolerance {
return log_likelihood_obs(data, &beta, &lambda);
}
}
panic!("Exceeded --max-iterations");
}
#[derive(Serialize, Deserialize)]
pub struct IntervalCensoredCoxResult {
pub params: Vec<f64>,
pub params_se: Vec<f64>,
pub cumulative_hazard: Vec<f64>,
pub cumulative_hazard_times: Vec<f64>,
pub ll_model: f64,
pub ll_null: f64,
}
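// With --output json, this struct is serialised with the field names above, e.g.
// (illustrative values only):
//   {"params": [0.5], "params_se": [0.1], "cumulative_hazard": [0.2, ...],
//    "cumulative_hazard_times": [0.0, ...], "ll_model": -10.0, "ll_null": -12.0}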
fn cumulative_hazard(lambda: &DVector<f64>) -> DVector<f64> {
let mut result = DVector::zeros(lambda.nrows());
for (i, value) in lambda.iter().enumerate() {
if i > 0 {
result[i] += result[i - 1];
}
result[i] += value;
}
return result;
}
fn max_abs_difference(vector_old: &DVector<f64>, vector_new: &DVector<f64>) -> f64 {
return (vector_new - vector_old).abs().max();
}