turnbull: Refactor to aid profiling
This commit is contained in:
parent
dd24de5813
commit
0a8c77fa2c
180
src/turnbull.rs
180
src/turnbull.rs
@ -172,18 +172,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
// Prepare for regression
|
// Prepare for regression
|
||||||
|
|
||||||
// Get Turnbull intervals
|
// Get Turnbull intervals
|
||||||
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
|
let intervals = get_turnbull_intervals(&data_times);
|
||||||
all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
|
||||||
all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
|
|
||||||
all_time_points.dedup();
|
|
||||||
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
|
|
||||||
|
|
||||||
let mut intervals: Vec<(f64, f64)> = Vec::new();
|
|
||||||
for i in 1..all_time_points.len() {
|
|
||||||
if all_time_points[i - 1].1 == true && all_time_points[i].1 == false {
|
|
||||||
intervals.push((all_time_points[i - 1].0, all_time_points[i].0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recode times as indexes
|
// Recode times as indexes
|
||||||
let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
|
let data_time_interval_indexes: Vec<(usize, usize)> = data_times.row_iter().map(|t| {
|
||||||
@ -202,7 +191,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
// Initialise s
|
// Initialise s
|
||||||
let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
|
let mut s = DVector::repeat(intervals.len(), 1.0 / intervals.len() as f64);
|
||||||
|
|
||||||
let data = TurnbullData {
|
let mut data = TurnbullData {
|
||||||
data_time_interval_indexes: data_time_interval_indexes,
|
data_time_interval_indexes: data_time_interval_indexes,
|
||||||
intervals: intervals,
|
intervals: intervals,
|
||||||
};
|
};
|
||||||
@ -215,48 +204,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
progress_bar.reset();
|
progress_bar.reset();
|
||||||
progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
|
progress_bar.println("Running iterative algorithm to fit Turnbull estimator");
|
||||||
|
|
||||||
let mut iteration = 1;
|
s = fit_turnbull_estimator(&mut data, progress_bar.clone(), max_iterations, fail_prob_tolerance, s);
|
||||||
loop {
|
|
||||||
// Get total failure probability for each observation (denominator of μ_ij)
|
|
||||||
let sum_fail_prob = DVector::from_iterator(
|
|
||||||
data.num_obs(),
|
|
||||||
data.data_time_interval_indexes
|
|
||||||
.iter()
|
|
||||||
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
|
|
||||||
);
|
|
||||||
|
|
||||||
// Compute π_j
|
|
||||||
let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
|
|
||||||
for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
|
|
||||||
for j in *idx_left..(*idx_right + 1) {
|
|
||||||
pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
|
|
||||||
|
|
||||||
let converged = largest_delta_s <= fail_prob_tolerance;
|
|
||||||
|
|
||||||
s = pi;
|
|
||||||
|
|
||||||
// Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations
|
|
||||||
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
|
|
||||||
let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64;
|
|
||||||
|
|
||||||
// Update progress bar
|
|
||||||
progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
|
|
||||||
progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s));
|
|
||||||
|
|
||||||
if converged {
|
|
||||||
progress_bar.println(format!("Converged in {} iterations", iteration));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
iteration += 1;
|
|
||||||
if iteration > max_iterations {
|
|
||||||
panic!("Exceeded --max-iterations");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
|
// Get survival probabilities (1 - cumulative failure probability), excluding at t=0 (prob=1) and t=inf (prob=0)
|
||||||
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
|
let mut survival_prob: Vec<f64> = Vec::with_capacity(data.num_intervals() - 1);
|
||||||
@ -274,32 +222,7 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
progress_bar.reset();
|
progress_bar.reset();
|
||||||
progress_bar.println("Computing standard errors for survival probabilities");
|
progress_bar.println("Computing standard errors for survival probabilities");
|
||||||
|
|
||||||
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
let hessian = compute_hessian(&data, progress_bar.clone(), &s);
|
||||||
|
|
||||||
for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
|
|
||||||
let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
|
|
||||||
hessian_denominator = hessian_denominator.powi(2);
|
|
||||||
|
|
||||||
let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
|
|
||||||
let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian
|
|
||||||
|
|
||||||
for h in idx_start..idx_end {
|
|
||||||
let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
|
|
||||||
let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
|
|
||||||
|
|
||||||
hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
|
|
||||||
|
|
||||||
for k in idx_start..h {
|
|
||||||
let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
|
|
||||||
let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
|
|
||||||
|
|
||||||
let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
|
|
||||||
hessian[(h, k)] -= value;
|
|
||||||
hessian[(k, h)] -= value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
progress_bar.finish();
|
|
||||||
|
|
||||||
let mut survival_prob_se: DVector<f64>;
|
let mut survival_prob_se: DVector<f64>;
|
||||||
|
|
||||||
@ -346,6 +269,101 @@ pub fn fit_turnbull(data_times: MatrixXx2<f64>, progress_bar: ProgressBar, max_i
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_turnbull_intervals(data_times: &MatrixXx2<f64>) -> Vec<(f64, f64)> {
|
||||||
|
let mut all_time_points: Vec<(f64, bool)> = Vec::new(); // Vec of (time, is_left)
|
||||||
|
all_time_points.extend(data_times.column(1).iter().map(|t| (*t, false))); // So we have right bounds before left bounds when sorted - ensures correct behaviour since intervals are left-open
|
||||||
|
all_time_points.extend(data_times.column(0).iter().map(|t| (*t, true)));
|
||||||
|
all_time_points.dedup();
|
||||||
|
all_time_points.sort_by(|(t1, _), (t2, _)| t1.partial_cmp(t2).unwrap());
|
||||||
|
|
||||||
|
let mut intervals: Vec<(f64, f64)> = Vec::new();
|
||||||
|
for i in 1..all_time_points.len() {
|
||||||
|
if all_time_points[i - 1].1 == true && all_time_points[i].1 == false {
|
||||||
|
intervals.push((all_time_points[i - 1].0, all_time_points[i].0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return intervals;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fit_turnbull_estimator(data: &mut TurnbullData, progress_bar: ProgressBar, max_iterations: u32, fail_prob_tolerance: f64, mut s: DVector<f64>) -> DVector<f64> {
|
||||||
|
let mut iteration = 1;
|
||||||
|
loop {
|
||||||
|
// Get total failure probability for each observation (denominator of μ_ij)
|
||||||
|
let sum_fail_prob = DVector::from_iterator(
|
||||||
|
data.num_obs(),
|
||||||
|
data.data_time_interval_indexes
|
||||||
|
.iter()
|
||||||
|
.map(|(idx_left, idx_right)| s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Compute π_j
|
||||||
|
let mut pi: DVector<f64> = DVector::zeros(data.num_intervals());
|
||||||
|
for (i, (idx_left, idx_right)) in data.data_time_interval_indexes.iter().enumerate() {
|
||||||
|
for j in *idx_left..(*idx_right + 1) {
|
||||||
|
pi[j] += s[j] / sum_fail_prob[i] / data.num_obs() as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let largest_delta_s = s.iter().zip(pi.iter()).map(|(x, y)| (y - x).abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
|
||||||
|
|
||||||
|
let converged = largest_delta_s <= fail_prob_tolerance;
|
||||||
|
|
||||||
|
s = pi;
|
||||||
|
|
||||||
|
// Estimate progress bar according to either the order of magnitude of the largest_delta_s relative to tolerance, or iteration/max_iterations
|
||||||
|
let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
|
||||||
|
let progress3 = ((-largest_delta_s.log10()).max(0.0) / -fail_prob_tolerance.log10() * u64::MAX as f64) as u64;
|
||||||
|
|
||||||
|
// Update progress bar
|
||||||
|
progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
|
||||||
|
progress_bar.set_message(format!("Iteration {} (max Δs = {:.4})", iteration + 1, largest_delta_s));
|
||||||
|
|
||||||
|
if converged {
|
||||||
|
progress_bar.println(format!("Converged in {} iterations", iteration));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration += 1;
|
||||||
|
if iteration > max_iterations {
|
||||||
|
panic!("Exceeded --max-iterations");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_hessian(data: &TurnbullData, progress_bar: ProgressBar, s: &DVector<f64>) -> DMatrix<f64> {
|
||||||
|
let mut hessian: DMatrix<f64> = DMatrix::zeros(data.num_intervals() - 1, data.num_intervals() - 1);
|
||||||
|
|
||||||
|
for (idx_left, idx_right) in data.data_time_interval_indexes.iter().progress_with(progress_bar.clone()) {
|
||||||
|
let mut hessian_denominator = s.view((*idx_left, 0), (*idx_right - *idx_left + 1, 1)).sum();
|
||||||
|
hessian_denominator = hessian_denominator.powi(2);
|
||||||
|
|
||||||
|
let idx_start = if *idx_left > 0 { *idx_left - 1 } else { 0 }; // To cover the h+1 case
|
||||||
|
let idx_end = (*idx_right + 1).min(data.num_intervals() - 1); // Go up to and including idx_right but don't go beyond hessian
|
||||||
|
|
||||||
|
for h in idx_start..idx_end {
|
||||||
|
let i_h = if h >= *idx_left && h <= *idx_right { 1.0 } else { 0.0 };
|
||||||
|
let i_h1 = if h + 1 >= *idx_left && h + 1 <= *idx_right { 1.0 } else { 0.0 };
|
||||||
|
|
||||||
|
hessian[(h, h)] -= (i_h - i_h1) * (i_h - i_h1) / hessian_denominator;
|
||||||
|
|
||||||
|
for k in idx_start..h {
|
||||||
|
let i_k = if k >= *idx_left && k <= *idx_right { 1.0 } else { 0.0 };
|
||||||
|
let i_k1 = if k + 1 >= *idx_left && k + 1 <= *idx_right { 1.0 } else { 0.0 };
|
||||||
|
|
||||||
|
let value = (i_h - i_h1) * (i_k - i_k1) / hessian_denominator;
|
||||||
|
hessian[(h, k)] -= value;
|
||||||
|
hessian[(k, h)] -= value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
progress_bar.finish();
|
||||||
|
|
||||||
|
return hessian;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct TurnbullResult {
|
pub struct TurnbullResult {
|
||||||
pub failure_intervals: Vec<(f64, f64)>,
|
pub failure_intervals: Vec<(f64, f64)>,
|
||||||
|
Loading…
Reference in New Issue
Block a user