diff --git a/src/lib.rs b/src/lib.rs
index 0c6eb20..73933ed 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod intcox;
+pub mod turnbull;
 
 mod pava;
 mod term;
diff --git a/src/main.rs b/src/main.rs
index 4938ac9..40b1f67 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -17,6 +17,7 @@
 use clap::{Parser, Subcommand};
 
 use hpstat::intcox;
+use hpstat::turnbull;
 
 #[derive(Parser)]
 #[command(about="High-performance statistics implementations")]
@@ -29,6 +30,9 @@ struct MainArgs {
 enum Command {
     #[command(name="intcox", about="Interval-censored Cox regression", long_about="Fit a Cox proportional hazards model on time-independent interval-censored observations")]
     IntCox(intcox::IntCoxArgs),
+
+    #[command(name="turnbull", about="Interval-censored Turnbull survival estimation", long_about="Fit a Turnbull survival estimator on interval-censored observations")]
+    Turnbull(turnbull::TurnbullArgs),
 }
 
 fn main() {
@@ -36,5 +40,6 @@ fn main() {
 
     match args.command {
         Command::IntCox(intcox_args) => intcox::main(intcox_args),
+        Command::Turnbull(turnbull_args) => turnbull::main(turnbull_args),
     }
 }
diff --git a/src/turnbull.rs b/src/turnbull.rs
new file mode 100644
index 0000000..6d6ac67
--- /dev/null
+++ b/src/turnbull.rs
@@ -0,0 +1,307 @@
+// hpstat: High-performance statistics implementations
+// Copyright © 2023 Lee Yingtong Li (RunasSudo)
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+const Z_97_5: f64 = 1.959964; // 97.5th percentile of the standard normal distribution
+
+use core::mem::MaybeUninit;
+use std::io;
+
+use clap::{Args, ValueEnum};
+use csv::{Reader, StringRecord};
+use indicatif::{ProgressBar, ProgressDrawTarget, ProgressIterator, ProgressStyle};
+use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
+use prettytable::{Table, format, row};
+use serde::{Serialize, Deserialize};
+
+use crate::term::UnconditionalTermLike;
+
+#[derive(Args)]
+pub struct TurnbullArgs {
+    /// Path to CSV input file containing the observations
+    #[arg()]
+    input: String,
+
+    /// Output format
+    #[arg(long, value_enum, default_value="text")]
+    output: OutputFormat,
+
+    /// Maximum number of iterations to attempt
+    #[arg(long, default_value="1000")]
+    max_iterations: u32,
+
+    /// Terminate the algorithm when the absolute change in failure probability in every interval is less than this tolerance
+    #[arg(long, default_value="0.0001")]
+    fail_prob_tolerance: f64,
+}
+
+#[derive(ValueEnum, Clone)]
+enum OutputFormat {
+    Text,
+    Json
+}
+
+pub fn main(args: TurnbullArgs) {
+    // Read data
+    let data_times = read_data(&args.input);
+
+    // Fit the Turnbull estimator
+    let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
+    let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.fail_prob_tolerance);
+
+    // Display output
+    match args.output {
+        OutputFormat::Text => {
+            println!();
+            println!();
+
+            let mut summary = Table::new();
+            let format = format::FormatBuilder::new()
+                .separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
+                .padding(2, 2)
+                .build();
+            summary.set_format(format);
+
+            summary.set_titles(row!["Time", c->"Surv. Prob.", c->"Std Err.", H2c->"(95% CI)"]);
+            summary.add_row(row![r->"0.000", r->"1.00000", "", "", ""]);
+            for (i, prob) in result.survival_prob.iter().enumerate() {
+                summary.add_row(row![
+                    r->format!("{:.3}", result.failure_intervals[i].1),
+                    r->format!("{:.5}", prob),
+                    r->format!("{:.5}", result.survival_prob_se[i]),
+                    r->format!("({:.5},", prob - Z_97_5 * result.survival_prob_se[i]),
+                    format!("{:.5})", prob + Z_97_5 * result.survival_prob_se[i]),
+                ]);
+            }
+            summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", "", ""]);
+            summary.printstd();
+        }
+        OutputFormat::Json => {
+            println!("{}", serde_json::to_string(&result).unwrap());
+        }
+    }
+}
+
+pub fn read_data(path: &str) -> MatrixXx2<f64> {
+    // Read CSV into memory
+    let _headers: StringRecord;
+    let records: Vec<StringRecord>;
+    if path == "-" {
+        let mut csv_reader = Reader::from_reader(io::stdin());
+        _headers = csv_reader.headers().unwrap().clone();
+        records = csv_reader.records().map(|r| r.unwrap()).collect();
+    } else {
+        let mut csv_reader = Reader::from_path(path).unwrap();
+        _headers = csv_reader.headers().unwrap().clone();
+        records = csv_reader.records().map(|r| r.unwrap()).collect();
+    }
+
+    // Read data into matrices
+
+    let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
+        Dyn(records.len()),
+        Const::<2> // Left time, right time
+    );
+
+    // Parse data
+    for (i, row) in records.iter().enumerate() {
+        for (j, item) in row.iter().enumerate() {
+            let value = match item {
+                "inf" => f64::INFINITY,
+                _ => item.parse().expect("Malformed float")
+            };
+
+            data_times[(i, j)].write(value);
+        }
+    }
+
+    // TODO: Fail on left time > right time
+    // TODO: Fail on left time < 0
+
+    // SAFETY: assume_init is OK because we initialised all values above
+    unsafe {
+        return data_times.assume_init();
+    }
+}
+
+struct TurnbullData {
+    data_time_interval_indexes: Vec<(usize, usize)>,
+
+    // Cached intermediate values
+    intervals: Vec<(f64, f64)>,
+}
+
+impl TurnbullData {
+    fn num_obs(&self) -> usize {
+        return self.data_time_interval_indexes.len();
+    }
+
+    fn num_intervals(&self) -> usize {
+        return self.intervals.len();
+    }
+}
+
+pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64) -> TurnbullResult {
+    // -----------------------
+    // Prepare for fitting
+
+    // Get Turnbull intervals
+    let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
+    all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // Add right bounds first, so they sort before left bounds at equal times (the sort is stable) - ensures correct behaviour since intervals are left-open
+    all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
+    all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
+    all_time_points.dedup();
+
+    // A Turnbull interval runs from a left bound to the next right bound, with no other bound in between
+    let mut intervals: Vec<(f64, f64)> = Vec::new();
+    for i in 1..all_time_points.len() {
+        if all_time_points[i - 1].1 == true && all_time_points[i].1 == false {
+            intervals.push((all_time_points[i - 1].0, all_time_points[i].0));
+        }
+    }
+
+    // Recode times as interval indexes
+    let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
+        let tleft = t[0];
+        let tright = t[1];
+
+        // Left index is the first interval whose left bound is >= the observation's left bound
+        let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
+
+        // Right index is the last interval whose right bound is <= the observation's right bound
+        let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
+
+        (left_index, right_index)
+    }).collect();
+
+    // Initialise s, the failure probability in each Turnbull interval, to a uniform distribution
+    let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
+
+    let data = TurnbullData {
+        data_time_interval_indexes: data_time_interval_indexes,
+        intervals: intervals,
+    };
+
+    // ------------------------------------------
+    // Apply iterative algorithm to fit estimator
+
+    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
+    progress_bar.set_length(u64::MAX);
+    progress_bar.reset();
+    progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
+
+    let mut iteration = 1;
+    loop {
+        // Get the total failure probability for each observation (the denominator of μ_ij)
+        let sum_fail_prob = DVector::from_iterator(
+            data.num_obs(),
+            data.data_time_interval_indexes
+                .iter()
+                .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
+        );
+
+        // Compute π_j = (1/n) Σ_i μ_ij, where μ_ij = s_j / sum_fail_prob[i] if interval j lies within observation i's interval, and 0 otherwise
+        let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
+        for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
+            for j in *idx_left..(*idx_right + 1) {
+                pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
+            }
+        }
+
+        let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
+
+        let converged = largest_delta_s <= fail_prob_tolerance;
+
+        s = pi;
+
+        // Estimate progress from whichever is further along: iteration/max_iterations, or the order of magnitude of largest_delta_s relative to the tolerance
+        let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
+        let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64;
+
+        // Update progress bar
+        progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
+        progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s));
+
+        if converged {
+            progress_bar.println(format!("Converged in {} iterations", iteration));
+            break;
+        }
+
+        iteration += 1;
+        if iteration > max_iterations {
+            panic!("Exceeded --max-iterations");
+        }
+    }
+
+    // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
+    let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
+    let mut acc = 1.0;
+    for j in 0..(data.num_intervals() - 1) {
+        acc -= s[j];
+        survival_prob.push(acc);
+    }
+
+    // --------------------------------------------------
+    // Compute standard errors for survival probabilities
+
+    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Compute Hessian {pos}/{len}").unwrap());
+    progress_bar.set_length(data.num_obs() as u64);
+    progress_bar.reset();
+    progress_bar.println("Computing standard errors for survival probabilities");
+
+    // The log-likelihood is Σ_i log(Σ_{j ∈ [l_i, r_i]} s_j). In terms of the survival probabilities S_h (so s_j = S_{j-1} - S_j),
+    // its Hessian has entries ∂²ℓ/∂S_h∂S_k = -Σ_i (I_i(h) - I_i(h+1)) * (I_i(k) - I_i(k+1)) / (Σ_{j ∈ [l_i, r_i]} s_j)²,
+    // where I_i(h) = 1 if l_i <= h <= r_i and 0 otherwise; these terms are accumulated observation by observation below
+    let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
+
+    for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
+        let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
+        hessian_denominator = hessian_denominator.powi(2);
+
+        let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
+        let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right, but don't go beyond the bounds of the hessian
+
+        for h in idx_start..idx_end {
+            let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
+            let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
+
+            hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
+
+            for k in idx_start..h {
+                let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
+                let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
+
+                let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
+                hessian[(h, k)] -= value;
+                hessian[(k, h)] -= value;
+            }
+        }
+    }
+    progress_bar.finish();
+
+    let vcov = (-hessian).try_inverse().expect("Matrix not invertible");
+    let survival_prob_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
+
+    return TurnbullResult {
+        failure_intervals: data.intervals,
+        failure_prob: s.data.as_vec().clone(),
+        survival_prob: survival_prob,
+        survival_prob_se: survival_prob_se.data.as_vec().clone(),
+    };
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct TurnbullResult {
+    pub failure_intervals: Vec<(f64, f64)>,
+    pub failure_prob: Vec<f64>,
+    pub survival_prob: Vec<f64>,
+    pub survival_prob_se: Vec<f64>,
+}
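
A minimal usage sketch follows, driving the new module as a library rather than through the hpstat turnbull subcommand. It relies only on the public items added above (read_data, fit_turnbull and the TurnbullResult fields); the file name data.csv, the hidden progress bar and the printed output are illustrative assumptions, not part of this change. The input CSV is expected to hold one row per observation with two columns, the left and right bounds of the censoring interval, with "inf" marking a right-censored observation.

    // Sketch only: assumes hpstat is available as a library dependency of this example.
    use hpstat::turnbull::{fit_turnbull, read_data};
    use indicatif::{ProgressBar, ProgressDrawTarget};

    fn main() {
        // Hypothetical input file; columns are left time and right time ("inf" = right-censored)
        let data_times = read_data("data.csv");

        // Hide the progress bar rather than drawing it to stderr
        let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::hidden());

        // Same defaults as the CLI: at most 1000 iterations, tolerance 0.0001
        let result = fit_turnbull(data_times, progress_bar, 1000, 0.0001);

        // Print the estimated failure probability for each Turnbull interval
        for ((left, right), prob) in result.failure_intervals.iter().zip(result.failure_prob.iter()) {
            println!("({:.3}, {:.3}]: {:.5}", left, right, prob);
        }
    }

The equivalent CLI invocation would be hpstat turnbull data.csv, with --max-iterations and --fail-prob-tolerance available to override the defaults used here.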