turnbull: Improve efficiency of EM step as suggested by Anderson-Bergman

This commit is contained in:
RunasSudo 2023-10-28 23:28:42 +11:00
parent 37c904bf34
commit 85e3ee0dcd
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 24 additions and 8 deletions

View File

@ -370,16 +370,32 @@ fn get_likelihood_obs(data: &TurnbullData, s: &Vec<f64>) -> Vec<f64> {
}
fn do_em_step(data: &TurnbullData, p: &Vec<f64>, s: &Vec<f64>) -> Vec<f64> {
// Update p
let mut p_new = Vec::with_capacity(data.num_intervals());
for j in 0..data.num_intervals() {
let tmp: f64 = data.data_time_interval_indexes.iter()
.filter(|(idx_left, idx_right)| j >= *idx_left && j <= *idx_right)
.map(|(idx_left, idx_right)| 1.0 / (s[*idx_left] - s[*idx_right + 1]))
.sum();
// Compute contributions to m
let mut m_contrib = vec![0.0; data.num_intervals()];
for (idx_left, idx_right) in data.data_time_interval_indexes.iter() {
let contrib = 1.0 / (s[*idx_left] - s[*idx_right + 1]);
p_new.push(p[j] * tmp / (data.num_obs() as f64));
// Adds to m for the first interval in the observation
m_contrib[*idx_left] += contrib;
// Subtracts from m for the first interval beyond the observation
if *idx_right + 1 < data.num_intervals() {
m_contrib[*idx_right + 1] -= contrib;
}
}
// Compute m
let mut m = Vec::with_capacity(data.num_intervals());
let mut m_last = 0.0;
for m_contrib_j in m_contrib {
let m_next = m_last + m_contrib_j / (data.num_obs() as f64);
m.push(m_next);
m_last = m_next;
}
// Update p
// p := p * m
let p_new = p.par_iter().zip(m.into_par_iter()).map(|(p_j, m_j)| p_j * m_j).collect();
return p_new;
}