From b691c5a8d71a11a928fd87aa80d4ba69b35416a8 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Thu, 9 Nov 2023 23:33:42 +1100 Subject: [PATCH] turnbull: Parallelise recoding times as indexes 11% speedup --- Cargo.lock | 1 + Cargo.toml | 4 ++-- src/turnbull.rs | 49 +++++++++++++++++++++++++++---------------------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c99a24..978331f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -418,6 +418,7 @@ dependencies = [ "num-complex", "num-rational", "num-traits", + "rayon", "simba", "typenum", ] diff --git a/Cargo.toml b/Cargo.toml index 72d1c63..5ddc5b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" clap = { version = "4.2.1", features = ["derive"] } console = "0.15.5" csv = "1.2.1" -indicatif = {version = "0.17.3", features = ["rayon"]} -nalgebra = "0.32.2" +indicatif = { version = "0.17.3", features = ["rayon"] } +nalgebra = { version = "0.32.2", features = ["rayon"] } prettytable-rs = "0.10.0" rayon = "1.7.0" serde = { version = "1.0.160", features = ["derive"] } diff --git a/src/turnbull.rs b/src/turnbull.rs index d220247..0419e07 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -23,7 +23,7 @@ use std::io; use clap::{Args, ValueEnum}; use csv::{Reader, StringRecord}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; -use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2}; +use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX}; use prettytable::{Table, format, row}; use rayon::prelude::*; use serde::{Serialize, Deserialize}; @@ -132,7 +132,7 @@ pub fn main(args: TurnbullArgs) { } } -pub fn read_data(path: &str) -> MatrixXx2 { +pub fn read_data(path: &str) -> Matrix2xX { // Read CSV into memory let _headers: StringRecord; let records: Vec; @@ -148,9 +148,10 @@ pub fn read_data(path: &str) -> MatrixXx2 { // Read data into matrices - let mut data_times: MatrixXx2> = MatrixXx2::uninit( - Dyn(records.len()), - Const::<2> // Left time, right time + // Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter) + let mut data_times: Matrix2xX> = Matrix2xX::uninit( + Const::<2>, // Left time, right time + Dyn(records.len()) ); // Parse data @@ -161,7 +162,7 @@ pub fn read_data(path: &str) -> MatrixXx2 { _ => item.parse().expect("Malformed float") }; - data_times[(i, j)].write(value); + data_times[(j, i)].write(value); } } @@ -197,7 +198,7 @@ struct Constraint { survival_prob: f64, } -pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { +pub fn fit_turnbull(data_times: Matrix2xX, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult { // ---------------------- // Prepare for regression @@ -205,18 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i let intervals = get_turnbull_intervals(&data_times); // Recode times as indexes - let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| { - let tleft = t[0]; - let tright = t[1]; - - // Left index is first interval >= observation left bound - let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0; - - // Right index is last interval <= observation right bound - let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0; - - (left_index, right_index) - }).collect(); + let data_time_interval_indexes = code_times_as_indexes(&data_times, &intervals); // Initialise p // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec @@ -286,10 +276,10 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i }; } -fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { +fn get_turnbull_intervals(data_times: &Matrix2xX) -> Vec<(f64, f64)> { let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left) - all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open - all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true))); + all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open + all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true))); all_time_points.dedup(); all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap()); @@ -303,6 +293,21 @@ fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { return intervals; } +fn code_times_as_indexes(data_times: &Matrix2xX, intervals: &Vec<(f64, f64)>) -> Vec<(usize, usize)> { + return data_times.par_column_iter().map(|t| { + let tleft = t[0]; + let tright = t[1]; + + // Left index is first interval >= observation left bound + let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0; + + // Right index is last interval <= observation right bound + let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0; + + (left_index, right_index) + }).collect(); +} + fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec, constraint: Option) -> (Vec, f64) { // Pre-compute S, the survival probability at the start of each interval let mut s = p_to_s(&p);