hpstat/src/turnbull.rs

427 lines
14 KiB
Rust

// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
/// Critical value of the standard normal distribution for a two-sided 95% confidence interval, i.e. Φ⁻¹(0.975)
const Z_97_5: f64 = 1.959963984540054;
use core::mem::MaybeUninit;
use std::io;
use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};
use crate::term::UnconditionalTermLike;
// Command-line arguments for fitting the Turnbull estimator
// (the `///` field comments below are rendered by clap as --help text)
#[derive(Args)]
pub struct TurnbullArgs {
    /// Path to CSV input file containing the observations
    #[arg()]
    input: String,

    /// Output format
    #[arg(long, value_enum, default_value="text")]
    output: OutputFormat,

    /// Maximum number of iterations to attempt
    #[arg(long, default_value="1000")]
    max_iterations: u32,

    /// Terminate algorithm when the absolute change in log-likelihood is less than this tolerance
    #[arg(long, default_value="0.01")]
    ll_tolerance: f64,

    /// Method for computing standard errors of survival probabilities
    #[arg(long, value_enum, default_value="oim")]
    se_method: SEMethod,

    /// Threshold for dropping failure probability in --se-method oim-drop-zeros
    #[arg(long, default_value="0.0001")]
    zero_tolerance: f64,
}
// How the fitted result is printed to stdout (selected with --output).
// Plain `//` comments are used on the variants because clap would render `///` docs as help text.
#[derive(ValueEnum, Clone)]
enum OutputFormat {
    // Log-likelihood followed by a summary table
    Text,
    // serde_json serialisation of the full result
    Json,
}
/// Method for computing standard errors of the survival probabilities (selected with --se-method).
#[derive(ValueEnum, Clone)]
pub enum SEMethod {
    // Invert the full observed information matrix
    OIM,
    // Invert the observed information matrix after dropping intervals whose failure probability is within --zero-tolerance of zero
    OIMDropZeros,
}
pub fn main(args: TurnbullArgs) {
// Read data
let data_times = read_data(&args.input);
// Fit regression
let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.ll_tolerance, args.se_method, args.zero_tolerance);
// Display output
match args.output {
OutputFormat::Text => {
println!();
println!();
println!("LL = {:.5}", result.ll_model);
let mut summary = Table::new();
let format = format::FormatBuilder::new()
.separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
.padding(2, 2)
.build();
summary.set_format(format);
summary.set_titles(row!["Time", c->"Surv. Prob.", c->"Std Err.", H2c->"(95% CI)"]);
summary.add_row(row![r->"0.000", r->"1.00000", "", "", ""]);
for (i, prob) in result.survival_prob.iter().enumerate() {
summary.add_row(row![
r->format!("{:.3}", result.failure_intervals[i].1),
r->format!("{:.5}", prob),
r->format!("{:.5}", result.survival_prob_se[i]),
r->format!("({:.5},", prob - Z_97_5 * result.survival_prob_se[i]),
format!("{:.5})", prob + Z_97_5 * result.survival_prob_se[i]),
]);
}
summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", "", ""]);
summary.printstd();
}
OutputFormat::Json => {
println!("{}", serde_json::to_string(&result).unwrap());
}
}
}
/// Read interval-censored observations from a CSV file, or from standard input when `path` is `-`.
///
/// The first CSV row is read as a header and ignored. Every following row must contain exactly two
/// columns: the left time and the right time of one observation. The value `inf` may be used for
/// an unbounded time.
///
/// Returns an N×2 matrix whose columns are the left and right times.
///
/// # Panics
/// Panics if the input cannot be read, if a row does not have exactly 2 columns, if a value is not
/// a valid non-NaN float, if a left time is negative, or if a left time exceeds its right time.
pub fn read_data(path: &str) -> MatrixXx2<f64> {
    // Read CSV into memory
    let _headers: StringRecord;
    let records: Vec<StringRecord>;
    if path == "-" {
        let mut csv_reader = Reader::from_reader(io::stdin());
        _headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    } else {
        let mut csv_reader = Reader::from_path(path).unwrap();
        _headers = csv_reader.headers().unwrap().clone();
        records = csv_reader.records().map(|r| r.unwrap()).collect();
    }

    // Parse one time value; NaN is rejected as it has no meaning as a time
    let parse_time = |item: &str| -> f64 {
        let value: f64 = match item {
            "inf" => f64::INFINITY,
            _ => item.parse().expect("Malformed float"),
        };
        assert!(!value.is_nan(), "Malformed float");
        value
    };

    // Read data into matrices
    let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
        Dyn(records.len()),
        Const::<2> // Left time, right time
    );

    // Parse data
    for (i, row) in records.iter().enumerate() {
        // Each row must fill exactly both columns: a shorter row would leave entries uninitialised
        // (undefined behaviour at assume_init below), and a longer one would index out of bounds
        assert_eq!(row.len(), 2, "Expected exactly 2 columns (left time, right time) in data row {}", i + 1);

        let left = parse_time(&row[0]);
        let right = parse_time(&row[1]);
        assert!(left >= 0.0, "Left time must not be negative (data row {})", i + 1);
        assert!(left <= right, "Left time must not exceed right time (data row {})", i + 1);

        data_times[(i, 0)].write(left);
        data_times[(i, 1)].write(right);
    }

    // SAFETY: every row was checked above to have exactly 2 fields, and entries (i, 0) and (i, 1)
    // were written for each of the records.len() rows, so every value is initialised
    unsafe { data_times.assume_init() }
}
/// Observations recoded in terms of the Turnbull intervals they contain.
struct TurnbullData {
    /// For each observation, the (first, last) indexes — both inclusive — of the Turnbull intervals it spans
    data_time_interval_indexes: Vec<(usize, usize)>,

    // Cached intermediate values
    /// The Turnbull intervals as (left, right) time pairs
    intervals: Vec<(f64, f64)>,
}

impl TurnbullData {
    /// Number of observations
    fn num_obs(&self) -> usize {
        self.data_time_interval_indexes.len()
    }

    /// Number of Turnbull intervals
    fn num_intervals(&self) -> usize {
        self.intervals.len()
    }
}
/// Fit the Turnbull nonparametric maximum likelihood estimator of the survival function.
///
/// `data_times` holds one observation per row as (left time, right time). Each row is treated as the
/// left-open interval (left, right]. The failure probabilities are fitted iteratively by
/// `fit_turnbull_estimator`, which panics after `max_iterations`. Standard errors of the survival
/// probabilities come from the observed information matrix according to `se_method`. With
/// `SEMethod::OIMDropZeros`, intervals whose failure probability is not greater than
/// `zero_tolerance` are excluded from the matrix before it is inverted.
///
/// # Panics
/// Panics if no Turnbull intervals can be formed, if an observation contains no Turnbull interval,
/// if the algorithm does not converge within `max_iterations`, or if the information matrix is not
/// invertible.
pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
    // ----------------------
    // Prepare for regression

    // Get Turnbull intervals
    let intervals = get_turnbull_intervals(&data_times);
    assert!(!intervals.is_empty(), "No Turnbull intervals could be formed from the data");

    // Recode times as indexes
    let data_time_interval_indexes = get_data_time_interval_indexes(&data_times, &intervals);

    // Initialise s
    // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
    let s = vec![1.0 / intervals.len() as f64; intervals.len()];

    let mut data = TurnbullData {
        data_time_interval_indexes,
        intervals,
    };

    // ------------------------------------------
    // Apply iterative algorithm to fit estimator

    progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
    progress_bar.set_length(u64::MAX);
    progress_bar.reset();
    progress_bar.println("Running iterative algorithm to fit Turnbull estimator");

    let (s, ll) = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, ll_tolerance, s);

    // Survival probabilities (1 - cumulative failure probability) at the right bound of each interval,
    // excluding t=0 (prob=1) and the end of the last interval (prob=0)
    let num_params = data.num_intervals() - 1;
    let survival_prob: Vec<f64> = s[..num_params]
        .iter()
        .scan(1.0, |acc, &s_j| {
            *acc -= s_j;
            Some(*acc)
        })
        .collect();

    // --------------------------------------------------
    // Compute standard errors for survival probabilities

    let hessian = compute_hessian(&data, &s);
    let survival_prob_se = compute_survival_prob_se(hessian, &s, se_method, zero_tolerance);

    TurnbullResult {
        failure_intervals: data.intervals,
        failure_prob: s,
        survival_prob,
        survival_prob_se,
        ll_model: ll,
    }
}

/// Recode each observation (left, right] as the inclusive range (first, last) of indexes of the
/// Turnbull intervals it contains.
///
/// `intervals` must be sorted in increasing order of time, as built by `get_turnbull_intervals`,
/// so both bounds are found by binary search.
fn get_data_time_interval_indexes(data_times: &MatrixXx2<f64>, intervals: &[(f64, f64)]) -> Vec<(usize, usize)> {
    data_times
        .row_iter()
        .map(|t| {
            let (tleft, tright) = (t[0], t[1]);
            // Left index is first interval whose left bound >= observation left bound
            let left_index = intervals.partition_point(|&(ileft, _)| ileft < tleft);
            // Right index is last interval whose right bound <= observation right bound
            let right_index = intervals.partition_point(|&(_, iright)| iright <= tright).checked_sub(1);
            match right_index {
                Some(right_index) if left_index <= right_index => (left_index, right_index),
                _ => panic!("Observation ({}, {}] does not contain any Turnbull interval", tleft, tright),
            }
        })
        .collect()
}

/// Compute standard errors of the survival probabilities from the log-likelihood Hessian, using
/// the observed information matrix (the negative Hessian).
///
/// With `SEMethod::OIMDropZeros`, parameters whose interval failure probability `s[i]` is not
/// greater than `zero_tolerance` are dropped before inversion. Each dropped survival probability
/// takes the standard error of the preceding one. The first one gets 0, because survival at time 0
/// is fixed at 1.
fn compute_survival_prob_se(hessian: DMatrix<f64>, s: &[f64], se_method: SEMethod, zero_tolerance: f64) -> Vec<f64> {
    match se_method {
        SEMethod::OIM => {
            // Compute covariance matrix as inverse of negative Hessian
            let vcov = -hessian.try_inverse().expect("Matrix not invertible");
            let variances: DVector<f64> = vcov.diagonal();
            variances.iter().map(|v| v.sqrt()).collect()
        }
        SEMethod::OIMDropZeros => {
            let num_params = hessian.nrows();
            let is_nonzero: Vec<bool> = s[..num_params].iter().map(|&s_i| s_i > zero_tolerance).collect();
            let nonzero_intervals: Vec<usize> = (0..num_params).filter(|&i| is_nonzero[i]).collect();

            // Drop rows/columns of Hessian corresponding to intervals with zero failure probability
            let num_nonzero = nonzero_intervals.len();
            let hessian_nonzero = DMatrix::from_fn(num_nonzero, num_nonzero, |r, c| {
                hessian[(nonzero_intervals[r], nonzero_intervals[c])]
            });
            let vcov = -hessian_nonzero.try_inverse().expect("Matrix not invertible");
            let variances_nonzero: DVector<f64> = vcov.diagonal();
            let mut se_nonzero = variances_nonzero.iter().map(|v| v.sqrt());

            // Scatter back to one standard error per survival probability
            let mut survival_prob_se: Vec<f64> = Vec::with_capacity(num_params);
            for &nonzero in &is_nonzero {
                let se = if nonzero {
                    se_nonzero.next().expect("one standard error per retained interval")
                } else {
                    // Survival is unchanged across this interval, so carry the previous standard error
                    // forward; before the first interval survival is fixed at 1, with no uncertainty
                    survival_prob_se.last().copied().unwrap_or(0.0)
                };
                survival_prob_se.push(se);
            }
            survival_prob_se
        }
    }
}
/// Compute the Turnbull (innermost) intervals of the observations in `data_times`.
///
/// Each row of `data_times` is an observation (left time, right time), treated as the left-open
/// interval (left, right]. A Turnbull interval is formed wherever a left bound is immediately
/// followed by a right bound in the sorted list of all bounds. Intervals are returned in
/// increasing order of time.
///
/// # Panics
/// Panics if any bound is NaN.
fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
    // All bounds as (time, is_left)
    let mut all_time_points: Vec<(f64, bool)> = data_times
        .column(1)
        .iter()
        .map(|&t| (t, false))
        .chain(data_times.column(0).iter().map(|&t| (t, true)))
        .collect();

    // Sort by time, placing right bounds (is_left = false) before left bounds at equal times.
    // Since intervals are left-open, an observation ending at t and one starting at t do not
    // overlap, so no Turnbull interval must be formed across that tie.
    all_time_points.sort_unstable_by(|(t1, is_left1), (t2, is_left2)| {
        t1.partial_cmp(t2)
            .expect("Observation times must not be NaN")
            .then(is_left1.cmp(is_left2))
    });

    // Repeated bounds are only adjacent once sorted, so deduplicate after sorting
    all_time_points.dedup();

    all_time_points
        .windows(2)
        .filter(|pair| pair[0].1 && !pair[1].1)
        .map(|pair| (pair[0].0, pair[1].0))
        .collect()
}
/// Run Turnbull's self-consistency iterations starting from the failure probabilities `s`.
///
/// Each iteration replaces `s` with π (from `compute_pi`) and recomputes the log-likelihood.
/// Iteration stops once the change in log-likelihood is no greater than `ll_tolerance`.
/// Progress is reported on `progress_bar`.
///
/// Returns the final failure probabilities and the corresponding log-likelihood.
///
/// # Panics
/// Panics if convergence is not reached within `max_iterations` iterations.
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut s: Vec<f64>) -> (Vec<f64>, f64) {
    // Log-likelihood of the model is the sum of each observation's log-likelihood
    fn total_log_likelihood(likelihoods: &[f64]) -> f64 {
        likelihoods.iter().map(|l| l.ln()).sum()
    }

    // Likelihood of each observation under the current s (denominator of μ_ij)
    let mut likelihood_obs = get_likelihood_obs(data, &s);
    let mut ll_model = total_log_likelihood(&likelihood_obs);

    // Fixed quantities used by the progress estimate
    let bar_length = u64::MAX as f64;
    let tolerance_magnitude = -(ll_tolerance.log10());

    let mut iteration: u32 = 1;
    loop {
        // Compute π_j and adopt it as the new s
        let pi = compute_pi(data, &s, likelihood_obs);
        let pi_likelihood_obs = get_likelihood_obs(data, &pi);
        let pi_ll_model = total_log_likelihood(&pi_likelihood_obs);
        let ll_change = pi_ll_model - ll_model;

        s = pi;
        likelihood_obs = pi_likelihood_obs;
        ll_model = pi_ll_model;

        // Progress is the further along of: iterations relative to max_iterations, or the order of
        // magnitude of ll_change relative to ll_tolerance; the bar never moves backwards
        let iteration_progress = (f64::from(iteration) / f64::from(max_iterations) * bar_length) as u64;
        let tolerance_progress = ((-(ll_change.log10())).max(0.0) / tolerance_magnitude * bar_length) as u64;
        let position = progress_bar.position().max(tolerance_progress.max(iteration_progress));
        progress_bar.set_position(position);
        progress_bar.set_message(format!("Iteration {} (LL = {:.4}, ΔLL = {:.4})", iteration + 1, ll_model, ll_change));

        if ll_change <= ll_tolerance {
            progress_bar.println(format!("Converged in {} iterations", iteration));
            return (s, ll_model);
        }

        iteration += 1;
        if iteration > max_iterations {
            panic!("Exceeded --max-iterations");
        }
    }
}
/// Compute the likelihood of each observation under the failure probabilities `s`.
///
/// For observation i this is Σ_j α_ij s_j, the total failure probability of the Turnbull intervals
/// `idx_left..=idx_right` that it spans (the denominator of μ_ij). Observations are processed in
/// parallel.
fn get_likelihood_obs(data: &TurnbullData, s: &[f64]) -> Vec<f64> {
    data.data_time_interval_indexes
        .par_iter()
        .map(|&(idx_left, idx_right)| s[idx_left..=idx_right].iter().sum())
        .collect()
}
/// Compute π, the updated failure probabilities for one self-consistency iteration.
///
/// π_j = (1/N) Σ_i α_ij s_j / (Σ_k α_ik s_k), where N is the number of observations. Observation i
/// spans intervals `idx_left..=idx_right`. `likelihood_obs` must hold each observation's
/// denominator Σ_k α_ik s_k for the same `s`, as returned by `get_likelihood_obs`.
///
/// Observations are processed in parallel. Each rayon fold accumulates its contributions into its
/// own vector, and the partial sums are then added together.
fn compute_pi(data: &TurnbullData, s: &[f64], likelihood_obs: Vec<f64>) -> Vec<f64> {
    let num_intervals = data.num_intervals();
    // Loop-invariant, so computed once rather than for every interval of every observation
    let num_obs = data.num_obs() as f64;

    data.data_time_interval_indexes
        .par_iter()
        .zip(likelihood_obs.par_iter())
        .fold_with(
            // Compute the contributions to pi[j] for each observation and sum them in parallel
            vec![0.0; num_intervals],
            |mut acc, (&(idx_left, idx_right), &likelihood_obs_i)| {
                // Contributions to pi[j] for the i-th observation (α_ij = 1 exactly on its span)
                for (acc_j, s_j) in acc[idx_left..=idx_right].iter_mut().zip(&s[idx_left..=idx_right]) {
                    *acc_j += s_j / likelihood_obs_i / num_obs;
                }
                acc
            },
        )
        .reduce(
            // Reduce all the sub-sums from fold_with into the total sum
            || vec![0.0; num_intervals],
            |mut acc, subsum| {
                acc.iter_mut().zip(&subsum).for_each(|(x, y)| *x += y);
                acc
            },
        )
}
/// Compute the Hessian of the log-likelihood with respect to the survival probabilities
/// S_0, …, S_{m−2}.
///
/// Here m is the number of Turnbull intervals and S_h = 1 − Σ_{j≤h} s_j.
///
/// Observation i spans intervals l..=r, so its likelihood is L_i = Σ_j α_ij s_j = S_{l−1} − S_r,
/// with S_{−1} = 1 and S_{m−1} = 0 fixed. Hence
/// ∂²ln L_i/∂S_h∂S_k = −(α_{i,h} − α_{i,h+1})(α_{i,k} − α_{i,k+1}) / L_i².
/// This is nonzero only for h, k ∈ {l − 1, r}.
///
/// # Panics
/// Panics if there are no Turnbull intervals.
fn compute_hessian(data: &TurnbullData, s: &[f64]) -> DMatrix<f64> {
    // One parameter per interval except the last, after which survival is fixed at 0
    let num_params = data
        .num_intervals()
        .checked_sub(1)
        .expect("At least one Turnbull interval is required");
    let mut hessian: DMatrix<f64> = DMatrix::zeros(num_params, num_params);

    for &(idx_left, idx_right) in &data.data_time_interval_indexes {
        // 1 / (Σ_j α_{i,j} s_j)², the reciprocal of the squared likelihood of observation i
        let likelihood_i: f64 = s[idx_left..=idx_right].iter().sum();
        let inv_likelihood_sq = likelihood_i.powi(-2);

        // The numerator of the second derivative is nonzero only where α steps between
        // consecutive intervals.
        // h1 = last interval before the observation (α steps 0 → 1); absent if the observation
        // starts at the first interval.
        let h1 = idx_left.checked_sub(1);
        // h2 = last interval of the observation (α steps 1 → 0); absent if it is the final
        // interval, whose survival probability is not a parameter.
        let h2 = if idx_right < num_params { Some(idx_right) } else { None };

        if let Some(h1) = h1 {
            // (h, k) = (h1, h1): numerator is -(0 - 1)(0 - 1) = -1
            hessian[(h1, h1)] -= inv_likelihood_sq;
        }
        if let Some(h2) = h2 {
            // (h, k) = (h2, h2): numerator is -(1 - 0)(1 - 0) = -1
            hessian[(h2, h2)] -= inv_likelihood_sq;
        }
        if let (Some(h1), Some(h2)) = (h1, h2) {
            // (h, k) = (h1, h2): numerator is -(0 - 1)(1 - 0) = 1
            hessian[(h1, h2)] += inv_likelihood_sq;
            // (h, k) = (h2, h1): numerator is -(1 - 0)(0 - 1) = 1
            hessian[(h2, h1)] += inv_likelihood_sq;
        }
    }

    hessian
}
/// Result of fitting the Turnbull estimator; serialised as-is for `--output json`.
#[derive(Serialize, Deserialize)]
pub struct TurnbullResult {
    /// Turnbull intervals as (left time, right time) pairs, in increasing order of time
    pub failure_intervals: Vec<(f64, f64)>,
    /// Estimated failure probability in each Turnbull interval
    pub failure_prob: Vec<f64>,
    /// Survival probability at the right bound of each Turnbull interval except the last
    /// (survival at time 0 is 1 and after the last interval is 0, so both are omitted)
    pub survival_prob: Vec<f64>,
    /// Standard error of each entry of `survival_prob`
    pub survival_prob_se: Vec<f64>,
    /// Log-likelihood of the fitted model
    pub ll_model: f64,
}