turnbull: Parallelise recoding times as indexes
11% speedup
This commit is contained in:
parent
b23ff26eac
commit
b691c5a8d7
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -418,6 +418,7 @@ dependencies = [
|
|||||||
"num-complex",
|
"num-complex",
|
||||||
"num-rational",
|
"num-rational",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
|
"rayon",
|
||||||
"simba",
|
"simba",
|
||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
@ -8,7 +8,7 @@ clap = { version = "4.2.1", features = ["derive"] }
|
|||||||
console = "0.15.5"
|
console = "0.15.5"
|
||||||
csv = "1.2.1"
|
csv = "1.2.1"
|
||||||
indicatif = { version = "0.17.3", features = ["rayon"] }
|
indicatif = { version = "0.17.3", features = ["rayon"] }
|
||||||
nalgebra = "0.32.2"
|
nalgebra = { version = "0.32.2", features = ["rayon"] }
|
||||||
prettytable-rs = "0.10.0"
|
prettytable-rs = "0.10.0"
|
||||||
rayon = "1.7.0"
|
rayon = "1.7.0"
|
||||||
serde = { version = "1.0.160", features = ["derive"] }
|
serde = { version = "1.0.160", features = ["derive"] }
|
||||||
|
@ -23,7 +23,7 @@ use std::io;
|
|||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use csv::{Reader, StringRecord};
|
use csv::{Reader, StringRecord};
|
||||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
|
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
||||||
use prettytable::{Table, format, row};
|
use prettytable::{Table, format, row};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
@ -132,7 +132,7 @@ pub fn main(args: TurnbullArgs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
pub fn read_data(path: &str) -> Matrix2xX<f64> {
|
||||||
// Read CSV into memory
|
// Read CSV into memory
|
||||||
let _headers: StringRecord;
|
let _headers: StringRecord;
|
||||||
let records: Vec<StringRecord>;
|
let records: Vec<StringRecord>;
|
||||||
@ -148,9 +148,10 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
|||||||
|
|
||||||
// Read data into matrices
|
// Read data into matrices
|
||||||
|
|
||||||
let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
|
// Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter)
|
||||||
Dyn(records.len()),
|
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
||||||
Const::<2> // Left time, right time
|
Const::<2>, // Left time, right time
|
||||||
|
Dyn(records.len())
|
||||||
);
|
);
|
||||||
|
|
||||||
// Parse data
|
// Parse data
|
||||||
@ -161,7 +162,7 @@ pub fn read_data(path: &str) -> MatrixXx2<f64> {
|
|||||||
_ => item.parse().expect("Malformed float")
|
_ => item.parse().expect("Malformed float")
|
||||||
};
|
};
|
||||||
|
|
||||||
data_times[(i, j)].write(value);
|
data_times[(j, i)].write(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,7 +198,7 @@ struct Constraint {
|
|||||||
survival_prob: f64,
|
survival_prob: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
pub fn fit_turnbull(data_times: Matrix2xX<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, se_method: SEMethod, zero_tolerance: f64) -> TurnbullResult {
|
||||||
// ----------------------
|
// ----------------------
|
||||||
// Prepare for regression
|
// Prepare for regression
|
||||||
|
|
||||||
@ -205,18 +206,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
let intervals = get_turnbull_intervals(&data_times);
|
let intervals = get_turnbull_intervals(&data_times);
|
||||||
|
|
||||||
// Recode times as indexes
|
// Recode times as indexes
|
||||||
let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
|
let data_time_interval_indexes = code_times_as_indexes(&data_times, &intervals);
|
||||||
let tleft = t[0];
|
|
||||||
let tright = t[1];
|
|
||||||
|
|
||||||
// Left index is first interval >= observation left bound
|
|
||||||
let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
|
|
||||||
|
|
||||||
// Right index is last interval <= observation right bound
|
|
||||||
let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
|
|
||||||
|
|
||||||
(left_index, right_index)
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
// Initialise p
|
// Initialise p
|
||||||
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
|
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
|
||||||
@ -286,10 +276,10 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
fn get_turnbull_intervals(data_times: &Matrix2xX<f64>) -> Vec<(f64, f64)> {
|
||||||
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
|
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
|
||||||
all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
all_time_points.extend(data_times.row(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
||||||
all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
|
all_time_points.extend(data_times.row(0).iter().map(|t| (*t, true)));
|
||||||
all_time_points.dedup();
|
all_time_points.dedup();
|
||||||
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
|
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
|
||||||
|
|
||||||
@ -303,6 +293,21 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
|||||||
return intervals;
|
return intervals;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn code_times_as_indexes(data_times: &Matrix2xX<f64>, intervals: &Vec<(f64, f64)>) -> Vec<(usize, usize)> {
|
||||||
|
return data_times.par_column_iter().map(|t| {
|
||||||
|
let tleft = t[0];
|
||||||
|
let tright = t[1];
|
||||||
|
|
||||||
|
// Left index is first interval >= observation left bound
|
||||||
|
let left_index = intervals.iter().enumerate().find(|(_i, (ileft, _))| *ileft >= tleft).unwrap().0;
|
||||||
|
|
||||||
|
// Right index is last interval <= observation right bound
|
||||||
|
let right_index = intervals.iter().enumerate().rev().find(|(_i, (_, iright))| *iright <= tright).unwrap().0;
|
||||||
|
|
||||||
|
(left_index, right_index)
|
||||||
|
}).collect();
|
||||||
|
}
|
||||||
|
|
||||||
fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>, constraint: Option<Constraint>) -> (Vec<f64>, f64) {
|
fn fit_turnbull_estimator(data: &TurnbullData, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64, mut p: Vec<f64>, constraint: Option<Constraint>) -> (Vec<f64>, f64) {
|
||||||
// Pre-compute S, the survival probability at the start of each interval
|
// Pre-compute S, the survival probability at the start of each interval
|
||||||
let mut s = p_to_s(&p);
|
let mut s = p_to_s(&p);
|
||||||
|
Loading…
Reference in New Issue
Block a user