turnbull: Parallelise recoding times as indexes

11% speedup
RunasSudo 2023-11-09 23:33:42 +11:00
parent b23ff26eac
commit b691c5a8d7
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
3 changed files with 30 additions and 24 deletions

Cargo.lock (generated)

@@ -418,6 +418,7 @@ dependencies = [
  "num-complex",
  "num-rational",
  "num-traits",
+ "rayon",
  "simba",
  "typenum",
 ]

Cargo.toml

@@ -7,8 +7,8 @@ edition = "2021"
 clap = { version = "4.2.1", features = ["derive"] }
 console = "0.15.5"
 csv = "1.2.1"
-indicatif = {version = "0.17.3", features = ["rayon"]}
-nalgebra = "0.32.2"
+indicatif = { version = "0.17.3", features = ["rayon"] }
+nalgebra = { version = "0.32.2", features = ["rayon"] }
 prettytable-rs = "0.10.0"
 rayon = "1.7.0"
 serde = { version = "1.0.160", features = ["derive"] }
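
Adding the "rayon" feature to nalgebra is what exposes the parallel column iterator used in the new code below; rayon itself was already a direct dependency. A minimal sketch of what the feature enables (illustrative data, not taken from this commit):

    use nalgebra::Matrix2xX;
    use rayon::prelude::*;

    fn main() {
        // Two observations stored one per column: (left, right) = (1, 3) and (2, 4)
        let times = Matrix2xX::from_column_slice(&[1.0, 3.0, 2.0, 4.0]);
        // par_column_iter() is only available with nalgebra's "rayon" feature enabled
        let widths: Vec<f64> = times.par_column_iter().map(|c| c[1] - c[0]).collect();
        assert_eq!(widths, vec![2.0, 2.0]);
    }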

src/turnbull.rs

@@ -23,7 +23,7 @@ use std::io;
 use clap::{Args, ValueEnum};
 use csv::{Reader, StringRecord};
 use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
-use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
+use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
 use prettytable::{Table, format, row};
 use rayon::prelude::*;
 use serde::{Serialize, Deserialize};
@@ -132,7 +132,7 @@ pub fn main(args: TurnbullArgs) {
     }
 }
 
-pub fn read_data(path: &str) -> MatrixXx2<f64> {
+pub fn read_data(path: &str) -> Matrix2xX<f64> {
     // Read CSV into memory
     let _headers: StringRecord;
     let records: Vec<StringRecord>;
@@ -148,9 +148,10 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
 
     // Read data into matrices
 
-    let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
-        Dyn(records.len()),
-        Const::<2> // Left time, right time
+    // Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter)
+    let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
+        Const::<2>, // Left time, right time
+        Dyn(records.len())
     );
 
     // Parse data
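
Since nalgebra matrices are column-major, the 2xX layout keeps each observation's (left, right) pair contiguous and makes one column correspond to one observation, which is what par_column_iter hands out as work items; it is also why the write in the next hunk swaps the subscripts to (j, i). A small sketch of the transposed layout (example values are illustrative):

    use nalgebra::Matrix2xX;

    fn main() {
        // Observations (1.0, 2.0) and (3.0, inf), one per column
        let data_times = Matrix2xX::from_column_slice(&[1.0, 2.0, 3.0, f64::INFINITY]);
        assert_eq!(data_times.ncols(), 2);        // one column per observation
        assert_eq!(data_times[(0, 1)], 3.0);      // row 0 = left bounds, column 1 = second observation
        assert_eq!(data_times.column(0)[1], 2.0); // right bound of the first observation
    }
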
@@ -161,7 +162,7 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
                 _ => item.parse().expect("Malformed float")
             };
-            data_times[(i, j)].write(value);
+            data_times[(j, i)].write(value);
         }
     }
@@ -197,7 +198,7 @@ struct Constraint {
     survival_prob: f64,
 }
 
-pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
+pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
     // ----------------------
     // Prepare for regression
@@ -205,18 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
     let intervals = get_turnbull_intervals(&data_times);
 
     // Recode times as indexes
-    let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
-        let tleft = t[0];
-        let tright = t[1];
-
-        // Left index is first interval >= observation left bound
-        let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
-
-        // Right index is last interval <= observation right bound
-        let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
-
-        (left_index, right_index)
-    }).collect();
+    let data_time_interval_indexes = code_times_as_indexes(&data_times, &intervals);
 
     // Initialise p
     // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
@@ -286,10 +276,10 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
     };
 }
 
-fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
+fn get_turnbull_intervals(data_times: &Matrix2xX<f64>) -> Vec<(f64, f64)> {
     let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
-    all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
-    all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
+    all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
+    all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true)));
 
     all_time_points.dedup();
     all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
@@ -303,6 +293,21 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
     return intervals;
 }
 
+fn code_times_as_indexes(data_times: &Matrix2xX<f64>, intervals: &Vec<(f64, f64)>) -> Vec<(usize, usize)> {
+    return data_times.par_column_iter().map(|t| {
+        let tleft = t[0];
+        let tright = t[1];
+
+        // Left index is first interval >= observation left bound
+        let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
+
+        // Right index is last interval <= observation right bound
+        let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
+
+        (left_index, right_index)
+    }).collect();
+}
+
 fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>, constraint: Option<Constraint>) -> (Vec<f64>, f64) {
     // Pre-compute S, the survival probability at the start of each interval
     let mut s = p_to_s(&p);
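
Note on the parallel recode above: since the recoding now runs on a rayon parallel iterator, correctness relies on rayon collecting an indexed parallel iterator in its original order, so data_time_interval_indexes still lines up with the observations. A quick illustrative check of that property (example data only, not from this commit):

    use nalgebra::Matrix2xX;
    use rayon::prelude::*;

    fn main() {
        // Three observations, one per column
        let data_times = Matrix2xX::from_column_slice(&[0.0, 1.0, 1.0, 2.0, 2.0, 3.0]);
        let par: Vec<f64> = data_times.par_column_iter().map(|t| t[1] - t[0]).collect();
        let seq: Vec<f64> = data_times.column_iter().map(|t| t[1] - t[0]).collect();
        // The parallel map yields results in the same observation order as the sequential one
        assert_eq!(par, seq);
    }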