From 85e3ee0dcdeedd683af85181d3a7d992b1485173 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 28 Oct 2023 23:28:42 +1100 Subject: [PATCH] turnbull: Improve efficiency of EM step as suggested by Anderson-Bergman --- src/turnbull.rs | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/turnbull.rs b/src/turnbull.rs index 6760727..7e31a02 100644 --- a/src/turnbull.rs +++ b/src/turnbull.rs @@ -370,17 +370,33 @@ fn get_likelihood_obs(data: &TurnbullData, s: &Vec) -> Vec { } fn do_em_step(data: &TurnbullData, p: &Vec, s: &Vec) -> Vec { - // Update p - let mut p_new = Vec::with_capacity(data.num_intervals()); - for j in 0..data.num_intervals() { - let tmp: f64 = data.data_time_interval_indexes.iter() - .filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right) - .map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1])) - .sum(); + // Compute contributions to m + let mut m_contrib = vec![0.0; data.num_intervals()]; + for (idx_left, idx_right) in data.data_time_interval_indexes.iter() { + let contrib = 1.0 / (s[*idx_left] - s[*idx_right + 1]); - p_new.push(p[j] * tmp / (data.num_obs() as f64)); + // Adds to m for the first interval in the observation + m_contrib[*idx_left] += contrib; + + // Subtracts from m for the first interval beyond the observation + if *idx_right + 1 < data.num_intervals() { + m_contrib[*idx_right + 1] -= contrib; + } } + // Compute m + let mut m = Vec::with_capacity(data.num_intervals()); + let mut m_last = 0.0; + for m_contrib_j in m_contrib { + let m_next = m_last + m_contrib_j / (data.num_obs() as f64); + m.push(m_next); + m_last = m_next; + } + + // Update p + // p := p * m + let p_new = p.par_iter().zip(m.into_par_iter()).map(|(p_j, m_j)| p_j * m_j).collect(); + return p_new; }