From 22a2deca8961b2ea92f8197b21c72678cb33c2cf Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 22 Oct 2023 18:41:40 +1100 Subject: [PATCH] turnbull: Further refactoring for profiling --- src/turnbull.rs | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/turnbull.rs b/src/turnbull.rs index 936df85..7713aaf 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -285,20 +285,10 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma let mut iteration = 1; loop { // Get total failure probability for each observation (denominator of μ_ij) - let sum_fail_prob = DVector::from_iterator( - data.num_obs(), - data.data_time_interval_indexes - .iter() - .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) - ); + let sum_fail_prob = get_sum_fail_prob(data, &s); // Compute π_j - let mut pi: DVector = DVector::zeros(data.num_intervals()); - for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() { - for j in *idx_left..(*idx_right + 1) { - pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64; - } - } + let pi = compute_pi(data, &s, sum_fail_prob); let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); @@ -328,6 +318,25 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma return s; } +fn get_sum_fail_prob(data: &TurnbullData, s: &DVector) -> DVector { + return DVector::from_iterator( + data.num_obs(), + data.data_time_interval_indexes + .iter() + .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) + ); +} + +fn compute_pi(data: &TurnbullData, s: &DVector, sum_fail_prob: DVector) -> DVector { + let mut pi: DVector = DVector::zeros(data.num_intervals()); + for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() { + for j in *idx_left..(*idx_right + 1) { + pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64; + } + } + return pi; +} + fn compute_hessian(data: &TurnbullData, s: &DVector) -> DMatrix { let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);