diff --git a/src/turnbull.rs b/src/turnbull.rs index ffc2c0c..020e80d 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -189,7 +189,8 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i }).collect(); // Initialise s - let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64); + // Faster to repeatedly index Vec than DVector, and we don't do any matrix arithmetic, so represent this as Vec + let s = vec![1.0 / intervals.len() as f64; intervals.len()]; let mut data = TurnbullData { data_time_interval_indexes: data_time_interval_indexes, @@ -204,7 +205,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i progress_bar.reset(); progress_bar.println("Running iterative algorithm to fit Turnbull estimator"); - s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s); + let s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s); // Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0) let mut survival_prob: Vec = Vec::with_capacity(data.num_intervals() - 1); @@ -258,7 +259,7 @@ pub fn fit_turnbull(data_times: MatrixXx2, progress_bar: ProgressBar, max_i return TurnbullResult { failure_intervals: data.intervals, - failure_prob: s.data.as_vec().clone(), + failure_prob: s, survival_prob: survival_prob, survival_prob_se: survival_prob_se.data.as_vec().clone(), }; @@ -281,7 +282,7 @@ fn get_turnbull_intervals(data_times: &MatrixXx2) -> Vec<(f64, f64)> { return intervals; } -fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector) -> DVector { +fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: Vec) -> Vec { let mut iteration = 1; loop { // Get total failure probability for each observation (denominator of μ_ij) @@ -318,35 +319,31 @@ fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, ma return s; } -fn get_sum_fail_prob(data: &TurnbullData, s: &DVector) -> DVector { - return DVector::from_iterator( - data.num_obs(), - data.data_time_interval_indexes - .iter() - .map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum()) - ); +fn get_sum_fail_prob(data: &TurnbullData, s: &Vec) -> Vec { + return data.data_time_interval_indexes + .iter() + .map(|(idx_left, idx_right)| s[*idx_left..(*idx_right + 1)].iter().sum()) + .collect(); } -fn compute_pi(data: &TurnbullData, s: &DVector, sum_fail_prob: DVector) -> DVector { - // Faster to repeatedly index Vec than DVector, so first work on Vec then convert to DVector at the end +fn compute_pi(data: &TurnbullData, s: &Vec, sum_fail_prob: Vec) -> Vec { let mut pi: Vec = vec![0.0; data.num_intervals()]; - let s_vec = s.data.as_vec(); for ((idx_left, idx_right), sum_fail_prob_i) in data.data_time_interval_indexes.iter().zip(sum_fail_prob.iter()) { for j in *idx_left..(*idx_right + 1) { - pi[j] += s_vec[j] / sum_fail_prob_i / data.num_obs() as f64; + pi[j] += s[j] / sum_fail_prob_i / data.num_obs() as f64; } } - return DVector::from_vec(pi); + return pi; } -fn compute_hessian(data: &TurnbullData, s: &DVector) -> DMatrix { +fn compute_hessian(data: &TurnbullData, s: &Vec) -> DMatrix { let mut hessian: DMatrix = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1); for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { // Compute 1 / (Σ_j α_{i,j} s_j) - let mut one_over_hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum(); + let mut one_over_hessian_denominator: f64 = s[*idx_left..(*idx_right + 1)].iter().sum(); one_over_hessian_denominator = one_over_hessian_denominator.powi(-2); // The numerator of the log-likelihood is -(α_{i,h} - α_{i,h+1})(α_{i,k} - α_{i,k+1})