Implement turnbull
This commit is contained in:
parent
6c5ab0dd60
commit
67ce046522
@@ -1,4 +1,5 @@
pub mod intcox;
pub mod turnbull;

mod pava;
mod term;
@@ -17,6 +17,7 @@
use clap::{Parser, Subcommand};

use hpstat::intcox;
use hpstat::turnbull;

#[derive(Parser)]
#[command(about="High-performance statistics implementations")]
@@ -29,6 +30,9 @@ struct MainArgs {
enum Command {
	#[command(name="intcox", about="Interval-censored Cox regression", long_about="Fit a Cox proportional hazards model on time-independent interval-censored observations")]
	IntCox(intcox::IntCoxArgs),
	
	#[command(name="turnbull", about="Interval-censored Turnbull survival estimation", long_about="Fit a Turnbull survival estimator on interval-censored observations")]
	Turnbull(turnbull::TurnbullArgs),
}

fn main() {
@@ -36,5 +40,6 @@ fn main() {
	
	match args.command {
		Command::IntCox(intcox_args) => intcox::main(intcox_args),
		Command::Turnbull(turnbull_args) => turnbull::main(turnbull_args),
	}
}
307
src/turnbull.rs
Normal file
@@ -0,0 +1,307 @@
// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64

use core::mem::MaybeUninit;
use std::io;

use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressIterator, ProgressStyle};
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
use prettytable::{Table, format, row};
use serde::{Serialize, Deserialize};

use crate::term::UnconditionalTermLike;

#[derive(Args)]
pub struct TurnbullArgs {
	/// Path to CSV input file containing the observations
	#[arg()]
	input: String,
	
	/// Output format
	#[arg(long, value_enum, default_value="text")]
	output: OutputFormat,
	
	/// Maximum number of iterations to attempt
	#[arg(long, default_value="1000")]
	max_iterations: u32,
	
	/// Terminate algorithm when the absolute change in failure probability in each interval is less than this tolerance
	#[arg(long, default_value="0.0001")]
	fail_prob_tolerance: f64,
}
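
// Note (editorial, illustrative): clap derives kebab-case flags from the fields above, so a
// typical invocation would look something like
//
//     hpstat turnbull data.csv --output json --max-iterations 500 --fail-prob-tolerance 0.00001
//
// Passing "-" as the input path reads the CSV from stdin (see read_data below).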

#[derive(ValueEnum, Clone)]
enum OutputFormat {
	Text,
	Json
}

pub fn main(args: TurnbullArgs) {
	// Read data
	let data_times = read_data(&args.input);
	
	// Fit regression
	let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
	let result = fit_turnbull(data_times, progress_bar, args.max_iterations, args.fail_prob_tolerance);
	
	// Display output
	match args.output {
		OutputFormat::Text => {
			println!();
			println!();
			
			let mut summary = Table::new();
			let format = format::FormatBuilder::new()
				.separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
				.padding(2, 2)
				.build();
			summary.set_format(format);
			
			summary.set_titles(row!["Time", c->"Surv. Prob.", c->"Std Err.", H2c->"(95% CI)"]);
			summary.add_row(row![r->"0.000", r->"1.00000", "", "", ""]);
			for (i, prob) in result.survival_prob.iter().enumerate() {
				summary.add_row(row![
					r->format!("{:.3}", result.failure_intervals[i].1),
					r->format!("{:.5}", prob),
					r->format!("{:.5}", result.survival_prob_se[i]),
					r->format!("({:.5},", prob - Z_97_5 * result.survival_prob_se[i]),
					format!("{:.5})", prob + Z_97_5 * result.survival_prob_se[i]),
				]);
			}
			summary.add_row(row![r->format!("{:.3}", result.failure_intervals.last().unwrap().1), r->"0.00000", "", "", ""]);
			summary.printstd();
		}
		OutputFormat::Json => {
			println!("{}", serde_json::to_string(&result).unwrap());
		}
	}
}

pub fn read_data(path: &str) -> MatrixXx2<f64> {
	// Read CSV into memory
	let _headers: StringRecord;
	let records: Vec<StringRecord>;
	if path == "-" {
		let mut csv_reader = Reader::from_reader(io::stdin());
		_headers = csv_reader.headers().unwrap().clone();
		records = csv_reader.records().map(|r| r.unwrap()).collect();
	} else {
		let mut csv_reader = Reader::from_path(path).unwrap();
		_headers = csv_reader.headers().unwrap().clone();
		records = csv_reader.records().map(|r| r.unwrap()).collect();
	}
	
	// Read data into matrices
	
	let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
		Dyn(records.len()),
		Const::<2> // Left time, right time
	);
	
	// Parse data
	for (i, row) in records.iter().enumerate() {
		for (j, item) in row.iter().enumerate() {
			let value = match item {
				"inf" => f64::INFINITY,
				_ => item.parse().expect("Malformed float")
			};
			
			data_times[(i, j)].write(value);
		}
	}
	
	// TODO: Fail on left time > right time
	// TODO: Fail on left time < 0
	
	// SAFETY: assume_init is OK because we initialised all values above
	unsafe {
		return data_times.assume_init();
	}
}
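
// Note (editorial, illustrative): the expected CSV input has a header row followed by two
// columns, the left and right bounds of each interval-censored observation, with "inf"
// denoting an infinite (right-censored) right bound. Column names are ignored; only the
// column order matters. For example:
//
//     left,right
//     0,2
//     1,3
//     3,inf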

struct TurnbullData {
	data_time_interval_indexes: Vec<(usize, usize)>,
	
	// Cached intermediate values
	intervals: Vec<(f64, f64)>,
}

impl TurnbullData {
	fn num_obs(&self) -> usize {
		return self.data_time_interval_indexes.len();
	}
	
	fn num_intervals(&self) -> usize {
		return self.intervals.len();
	}
}

pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64) -> TurnbullResult {
	// ----------------------
	// Prepare for regression
	
	// Get Turnbull intervals
	let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
	all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
	all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
	all_time_points.dedup();
	all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
	
	let mut intervals: Vec<(f64, f64)> = Vec::new();
	for i in 1..all_time_points.len() {
		if all_time_points[i - 1].1 == true && all_time_points[i].1 == false {
			intervals.push((all_time_points[i - 1].0, all_time_points[i].0));
		}
	}
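
	// Note (editorial): each Turnbull interval constructed above is a maximal interval whose
	// left endpoint is some observation's left bound and whose right endpoint is some
	// observation's right bound, with no other bound strictly inside it. As an illustrative
	// example, for the observations (0, 2], (1, 3] and (3, inf], the Turnbull intervals are
	// (1, 2] and (3, inf]; the estimator places all failure probability mass on such intervals.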

	// Recode times as indexes
	let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
		let tleft = t[0];
		let tright = t[1];
		
		// Left index is first interval >= observation left bound
		let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
		
		// Right index is last interval <= observation right bound
		let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
		
		(left_index, right_index)
	}).collect();
	
	// Initialise s
	let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
	
	let data = TurnbullData {
		data_time_interval_indexes: data_time_interval_indexes,
		intervals: intervals,
	};
	
	// ------------------------------------------
	// Apply iterative algorithm to fit estimator
	
	progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
	progress_bar.set_length(u64::MAX);
	progress_bar.reset();
	progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
	
	let mut iteration = 1;
	loop {
		// Get total failure probability for each observation (denominator of μ_ij)
		let sum_fail_prob = DVector::from_iterator(
			data.num_obs(),
			data.data_time_interval_indexes
				.iter()
				.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
		);
		
		// Compute π_j
		let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
		for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
			for j in *idx_left..(*idx_right + 1) {
				pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
			}
		}
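
		// Note (editorial): this is the standard Turnbull self-consistency (EM) update. Writing
		// each observation's interval index range as [l_i, r_i], the update computed above is
		//     π_j = (1/n) Σ_i I(l_i ≤ j ≤ r_i) s_j / (s_{l_i} + ... + s_{r_i}),
		// i.e. each observation distributes one unit of probability mass over the Turnbull
		// intervals it covers, in proportion to the current estimates s_j.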
		
		let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
		
		let converged = largest_delta_s <= fail_prob_tolerance;
		
		s = pi;
		
		// Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations
		let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
		let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64;
		
		// Update progress bar
		progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
		progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s));
		
		if converged {
			progress_bar.println(format!("Converged in {} iterations", iteration));
			break;
		}
		
		iteration += 1;
		if iteration > max_iterations {
			panic!("Exceeded --max-iterations");
		}
	}
	
	// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
	let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
	let mut acc = 1.0;
	for j in 0..(data.num_intervals() - 1) {
		acc -= s[j];
		survival_prob.push(acc);
	}
	
	// --------------------------------------------------
	// Compute standard errors for survival probabilities
	
	progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Compute Hessian {pos}/{len}").unwrap());
	progress_bar.set_length(data.num_obs() as u64);
	progress_bar.reset();
	progress_bar.println("Computing standard errors for survival probabilities");
	
	let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
	
	for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
		let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
		hessian_denominator = hessian_denominator.powi(2);
		
		let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
		let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian
		
		for h in idx_start..idx_end {
			let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
			let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
			
			hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
			
			for k in idx_start..h {
				let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
				let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
				
				let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
				hessian[(h, k)] -= value;
				hessian[(k, h)] -= value;
			}
		}
	}
	progress_bar.finish();
	
	let vcov = (-hessian).try_inverse().expect("Matrix not invertible");
	let survival_prob_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
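
	// Note (editorial): writing S_h for the survival probability at the end of interval h
	// (with S_{-1} = 1), each observation contributes log(S_{l_i - 1} - S_{r_i}) to the
	// log-likelihood, which is linear in the S_h, so
	//     ∂²ℓ/∂S_h ∂S_k = -Σ_i d_ih d_ik / (S_{l_i - 1} - S_{r_i})²,
	// where d_ih is ±1 only when h = l_i - 1 or h = r_i (the (i_h - i_h1) terms in the loop
	// above) and 0 otherwise. The standard errors are then the square roots of the diagonal
	// of the inverted observed information matrix (-hessian)⁻¹.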
	
	return TurnbullResult {
		failure_intervals: data.intervals,
		failure_prob: s.data.as_vec().clone(),
		survival_prob: survival_prob,
		survival_prob_se: survival_prob_se.data.as_vec().clone(),
	};
}

#[derive(Serialize, Deserialize)]
pub struct TurnbullResult {
	pub failure_intervals: Vec<(f64, f64)>,
	pub failure_prob: Vec<f64>,
	pub survival_prob: Vec<f64>,
	pub survival_prob_se: Vec<f64>,
}
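
// Note (editorial, illustrative): with --output json, main() serialises this struct with
// serde_json, so the output has the shape
//     {"failure_intervals":[[1.0,2.0],[3.0,4.0]],"failure_prob":[0.6,0.4],"survival_prob":[0.4],"survival_prob_se":[0.09]}
// (the values here are made up purely to show the structure).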