turnbull: Use Vec<f64> throughout for better performance

Further 15% speedup
This commit is contained in:
RunasSudo 2023-10-22 19:09:01 +11:00
parent 8205e4acbc
commit 18a0679476
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 15 additions and 18 deletions

View File

@ -189,7 +189,8 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
}).collect(); }).collect();
// Initialise s // Initialise s
let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64); // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
let s = vec![1.0 / intervals.len() as f64; intervals.len()];
let mut data = TurnbullData { let mut data = TurnbullData {
data_time_interval_indexes: data_time_interval_indexes, data_time_interval_indexes: data_time_interval_indexes,
@ -204,7 +205,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
progress_bar.reset(); progress_bar.reset();
progress_bar.println("Running iterative algorithm to fit Turnbull estimator"); progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s); let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0) // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1); let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
@ -258,7 +259,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
return TurnbullResult { return TurnbullResult {
failure_intervals: data.intervals, failure_intervals: data.intervals,
failure_prob: s.data.as_vec().clone(), failure_prob: s,
survival_prob: survival_prob, survival_prob: survival_prob,
survival_prob_se: survival_prob_se.data.as_vec().clone(), survival_prob_se: survival_prob_se.data.as_vec().clone(),
}; };
@ -281,7 +282,7 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
return intervals; return intervals;
} }
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector<f64>) -> DVector<f64> { fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec<f64>) -> Vec<f64> {
let mut iteration = 1; let mut iteration = 1;
loop { loop {
// Get total failure probability for each observation (denominator of μ_ij) // Get total failure probability for each observation (denominator of μ_ij)
@ -318,35 +319,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
return s; return s;
} }
fn get_sum_fail_prob(data: &TurnbullData, s: &DVector<f64>) -> DVector<f64> { fn get_sum_fail_prob(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
return DVector::from_iterator( return data.data_time_interval_indexes
data.num_obs(),
data.data_time_interval_indexes
.iter() .iter()
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) .map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum())
); .collect();
} }
fn compute_pi(data: &TurnbullData, s: &DVector<f64>, sum_fail_prob: DVector<f64>) -> DVector<f64> { fn compute_pi(data: &TurnbullData, s: &Vec<f64>, sum_fail_prob: Vec<f64>) -> Vec<f64> {
// Faster to repeatedly index Vec than DVector, so first work on Vec then convert to DVector at the end
let mut pi: Vec<f64> = vec![0.0; data.num_intervals()]; let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
let s_vec = s.data.as_vec();
for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) { for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) {
for j in *idx_left..(*idx_right + 1) { for j in *idx_left..(*idx_right + 1) {
pi[j] += s_vec[j] / sum_fail_prob_i / data.num_obs() as f64; pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64;
} }
} }
return DVector::from_vec(pi); return pi;
} }
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> { fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
// Compute 1 / (Σ_j α_{i,j} s_j) // Compute 1 / (Σ_j α_{i,j} s_j)
let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); let mut one_over_hessian_denominator: f64 = s[*idx_left..(*idx_right + 1)].iter().sum();
one_over_hessian_denominator = one_over_hessian_denominator.powi(-2); one_over_hessian_denominator = one_over_hessian_denominator.powi(-2);
// The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1}) // The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})