turnbull: Parallelise recoding times as indexes
11% speedup
This commit is contained in:
parent
b23ff26eac
commit
b691c5a8d7
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -418,6 +418,7 @@ dependencies = [
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"rayon",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
@ -7,8 +7,8 @@ edition = "2021"
|
||||
clap = { version = "4.2.1", features = ["derive"] }
|
||||
console = "0.15.5"
|
||||
csv = "1.2.1"
|
||||
indicatif = {version = "0.17.3", features = ["rayon"]}
|
||||
nalgebra = "0.32.2"
|
||||
indicatif = { version = "0.17.3", features = ["rayon"] }
|
||||
nalgebra = { version = "0.32.2", features = ["rayon"] }
|
||||
prettytable-rs = "0.10.0"
|
||||
rayon = "1.7.0"
|
||||
serde = { version = "1.0.160", features = ["derive"] }
|
||||
|
@ -23,7 +23,7 @@ use std::io;
|
||||
use clap::{Args, ValueEnum};
|
||||
use csv::{Reader, StringRecord};
|
||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
|
||||
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
||||
use prettytable::{Table, format, row};
|
||||
use rayon::prelude::*;
|
||||
use serde::{Serialize, Deserialize};
|
||||
@ -132,7 +132,7 @@ pub fn main(args: TurnbullArgs) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
||||
pub fn read_data(path: &str) -> Matrix2xX<f64> {
|
||||
// Read CSV into memory
|
||||
let _headers: StringRecord;
|
||||
let records: Vec<StringRecord>;
|
||||
@ -148,9 +148,10 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
||||
|
||||
// Read data into matrices
|
||||
|
||||
let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
|
||||
Dyn(records.len()),
|
||||
Const::<2> // Left time, right time
|
||||
// Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter)
|
||||
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
||||
Const::<2>, // Left time, right time
|
||||
Dyn(records.len())
|
||||
);
|
||||
|
||||
// Parse data
|
||||
@ -161,7 +162,7 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
||||
_ => item.parse().expect("Malformed float")
|
||||
};
|
||||
|
||||
data_times[(i, j)].write(value);
|
||||
data_times[(j, i)].write(value);
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,7 +198,7 @@ struct Constraint {
|
||||
survival_prob: f64,
|
||||
}
|
||||
|
||||
pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
||||
pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
||||
// ----------------------
|
||||
// Prepare for regression
|
||||
|
||||
@ -205,18 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
let intervals = get_turnbull_intervals(&data_times);
|
||||
|
||||
// Recode times as indexes
|
||||
let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
|
||||
let tleft = t[0];
|
||||
let tright = t[1];
|
||||
|
||||
// Left index is first interval >= observation left bound
|
||||
let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
|
||||
|
||||
// Right index is last interval <= observation right bound
|
||||
let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
|
||||
|
||||
(left_index, right_index)
|
||||
}).collect();
|
||||
let data_time_interval_indexes = code_times_as_indexes(&data_times, &intervals);
|
||||
|
||||
// Initialise p
|
||||
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
|
||||
@ -286,10 +276,10 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
};
|
||||
}
|
||||
|
||||
fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
||||
fn get_turnbull_intervals(data_times: &Matrix2xX<f64>) -> Vec<(f64, f64)> {
|
||||
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
|
||||
all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
||||
all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
|
||||
all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
||||
all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true)));
|
||||
all_time_points.dedup();
|
||||
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
|
||||
|
||||
@ -303,6 +293,21 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
fn code_times_as_indexes(data_times: &Matrix2xX<f64>, intervals: &Vec<(f64, f64)>) -> Vec<(usize, usize)> {
|
||||
return data_times.par_column_iter().map(|t| {
|
||||
let tleft = t[0];
|
||||
let tright = t[1];
|
||||
|
||||
// Left index is first interval >= observation left bound
|
||||
let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
|
||||
|
||||
// Right index is last interval <= observation right bound
|
||||
let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
|
||||
|
||||
(left_index, right_index)
|
||||
}).collect();
|
||||
}
|
||||
|
||||
fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>, constraint: Option<Constraint>) -> (Vec<f64>, f64) {
|
||||
// Pre-compute S, the survival probability at the start of each interval
|
||||
let mut s = p_to_s(&p);
|
||||
|
Loading…
Reference in New Issue
Block a user