turnbull: Use Vec<f64> throughout for better performance

Further 15% speedup
This commit is contained in:
RunasSudo 2023-10-22 19:09:01 +11:00
parent 8205e4acbc
commit 18a0679476
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -189,7 +189,8 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
}).collect();
// Initialise s
let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
let s = vec![1.0 / intervals.len() as f64; intervals.len()];
let mut data = TurnbullData {
data_time_interval_indexes: data_time_interval_indexes,
@ -204,7 +205,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
progress_bar.reset();
progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
@ -258,7 +259,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
return TurnbullResult {
failure_intervals: data.intervals,
failure_prob: s.data.as_vec().clone(),
failure_prob: s,
survival_prob: survival_prob,
survival_prob_se: survival_prob_se.data.as_vec().clone(),
};
@ -281,7 +282,7 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
return intervals;
}
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector<f64>) -> DVector<f64> {
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec<f64>) -> Vec<f64> {
let mut iteration = 1;
loop {
// Get total failure probability for each observation (denominator of μ_ij)
@ -318,35 +319,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
return s;
}
fn get_sum_fail_prob(data: &TurnbullData, s: &DVector<f64>) -> DVector<f64> {
return DVector::from_iterator(
data.num_obs(),
data.data_time_interval_indexes
fn get_sum_fail_prob(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
return data.data_time_interval_indexes
.iter()
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
);
.map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum())
.collect();
}
fn compute_pi(data: &TurnbullData, s: &DVector<f64>, sum_fail_prob: DVector<f64>) -> DVector<f64> {
// Faster to repeatedly index Vec than DVector, so first work on Vec then convert to DVector at the end
fn compute_pi(data: &TurnbullData, s: &Vec<f64>, sum_fail_prob: Vec<f64>) -> Vec<f64> {
let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
let s_vec = s.data.as_vec();
for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) {
for j in *idx_left..(*idx_right + 1) {
pi[j] += s_vec[j] / sum_fail_prob_i / data.num_obs() as f64;
pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64;
}
}
return DVector::from_vec(pi);
return pi;
}
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> {
fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
// Compute 1 / (Σ_j α_{i,j} s_j)
let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
let mut one_over_hessian_denominator: f64 = s[*idx_left..(*idx_right + 1)].iter().sum();
one_over_hessian_denominator = one_over_hessian_denominator.powi(-2);
// The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})