turnbull: Further refactoring for profiling

This commit is contained in:
RunasSudo 2023-10-22 18:41:40 +11:00
parent f043f7c67d
commit 22a2deca89
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 21 additions and 12 deletions

View File

@ -285,20 +285,10 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
let mut iteration = 1;
loop {
// Get total failure probability for each observation (denominator of μ_ij)
let sum_fail_prob = DVector::from_iterator(
data.num_obs(),
data.data_time_interval_indexes
.iter()
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
);
let sum_fail_prob = get_sum_fail_prob(data, &s);
// Compute π_j
let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
for j in *idx_left..(*idx_right + 1) {
pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
}
}
let pi = compute_pi(data, &s, sum_fail_prob);
let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
@ -328,6 +318,25 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
return s;
}
fn get_sum_fail_prob(data: &TurnbullData, s: &DVector<f64>) -> DVector<f64> {
return DVector::from_iterator(
data.num_obs(),
data.data_time_interval_indexes
.iter()
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
);
}
fn compute_pi(data: &TurnbullData, s: &DVector<f64>, sum_fail_prob: DVector<f64>) -> DVector<f64> {
let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
for j in *idx_left..(*idx_right + 1) {
pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
}
}
return pi;
}
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> {
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);