turnbull: Use Vec<f64> throughout for better performance
Further 15% speedup
This commit is contained in:
parent
8205e4acbc
commit
18a0679476
@ -189,7 +189,8 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
}).collect();
|
||||
|
||||
// Initialise s
|
||||
let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
|
||||
// Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec
|
||||
let s = vec![1.0 / intervals.len() as f64; intervals.len()];
|
||||
|
||||
let mut data = TurnbullData {
|
||||
data_time_interval_indexes: data_time_interval_indexes,
|
||||
@ -204,7 +205,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
progress_bar.reset();
|
||||
progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
|
||||
|
||||
s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
|
||||
let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
|
||||
|
||||
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
|
||||
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
|
||||
@ -258,7 +259,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
||||
|
||||
return TurnbullResult {
|
||||
failure_intervals: data.intervals,
|
||||
failure_prob: s.data.as_vec().clone(),
|
||||
failure_prob: s,
|
||||
survival_prob: survival_prob,
|
||||
survival_prob_se: survival_prob_se.data.as_vec().clone(),
|
||||
};
|
||||
@ -281,7 +282,7 @@ fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector<f64>) -> DVector<f64> {
|
||||
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec<f64>) -> Vec<f64> {
|
||||
let mut iteration = 1;
|
||||
loop {
|
||||
// Get total failure probability for each observation (denominator of μ_ij)
|
||||
@ -318,35 +319,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma
|
||||
return s;
|
||||
}
|
||||
|
||||
fn get_sum_fail_prob(data: &TurnbullData, s: &DVector<f64>) -> DVector<f64> {
|
||||
return DVector::from_iterator(
|
||||
data.num_obs(),
|
||||
data.data_time_interval_indexes
|
||||
.iter()
|
||||
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
|
||||
);
|
||||
fn get_sum_fail_prob(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
|
||||
return data.data_time_interval_indexes
|
||||
.iter()
|
||||
.map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum())
|
||||
.collect();
|
||||
}
|
||||
|
||||
fn compute_pi(data: &TurnbullData, s: &DVector<f64>, sum_fail_prob: DVector<f64>) -> DVector<f64> {
|
||||
// Faster to repeatedly index Vec than DVector, so first work on Vec then convert to DVector at the end
|
||||
fn compute_pi(data: &TurnbullData, s: &Vec<f64>, sum_fail_prob: Vec<f64>) -> Vec<f64> {
|
||||
let mut pi: Vec<f64> = vec![0.0; data.num_intervals()];
|
||||
let s_vec = s.data.as_vec();
|
||||
|
||||
for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) {
|
||||
for j in *idx_left..(*idx_right + 1) {
|
||||
pi[j] += s_vec[j] / sum_fail_prob_i / data.num_obs() as f64;
|
||||
pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64;
|
||||
}
|
||||
}
|
||||
|
||||
return DVector::from_vec(pi);
|
||||
return pi;
|
||||
}
|
||||
|
||||
fn compute_hessian(data: &TurnbullData, s: &DVector<f64>) -> DMatrix<f64> {
|
||||
fn compute_hessian(data: &TurnbullData, s: &Vec<f64>) -> DMatrix<f64> {
|
||||
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
||||
|
||||
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
|
||||
// Compute 1 / (Σ_j α_{i,j} s_j)
|
||||
let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
|
||||
let mut one_over_hessian_denominator: f64 = s[*idx_left..(*idx_right + 1)].iter().sum();
|
||||
one_over_hessian_denominator = one_over_hessian_denominator.powi(-2);
|
||||
|
||||
// The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})
|
||||
|
Loading…
x
Reference in New Issue
Block a user