Add comments to M step

This commit is contained in:
RunasSudo 2023-04-22 23:57:17 +10:00
parent 461eb8db5f
commit 45585a2b38
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 20 additions and 11 deletions

View File

@ -421,8 +421,8 @@ fn do_m_step(data: &IntervalCensoredCoxData, exp_beta_z: &Matrix1xX<f64>, beta:
// Split these steps into functions to make profiling easier // Split these steps into functions to make profiling easier
let (mut s0, s1, s2) = m_step_compute_s_values(data, xi_exp_beta_z); let (mut s0, s1, s2) = m_step_compute_s_values(data, xi_exp_beta_z);
let sigma = m_step_compute_sigma(data, &posterior_weight, &s0, &s1, &s2); let jacobian = m_step_compute_jacobian(data, &posterior_weight, &s0, &s1, &s2);
let new_beta = m_step_compute_new_beta(data, &posterior_weight, &s0, &s1, sigma, beta); let new_beta = m_step_compute_new_beta(data, &posterior_weight, &s0, &s1, jacobian, beta);
s0 = m_step_compute_s0(data, beta); s0 = m_step_compute_s0(data, beta);
let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0); let new_lambda = m_step_compute_new_lambda(data, &posterior_weight, &s0);
@ -433,6 +433,7 @@ fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matri
// ComputeSValues // ComputeSValues
// Compute s0 // Compute s0
// For each k, s0 is \sum_{i=1}^n I(t_k <= R*_j) E(ξ_j) exp(β^T Z_jk)
let mut s0: DVector<f64> = DVector::zeros(data.num_times()); // Elements are f64 let mut s0: DVector<f64> = DVector::zeros(data.num_times()); // Elements are f64
for i in 0..data.num_obs() { for i in 0..data.num_obs() {
let s0_contrib = xi_exp_beta_z[i]; let s0_contrib = xi_exp_beta_z[i];
@ -447,6 +448,8 @@ fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matri
s2_contrib[i] = xi_exp_beta_z[i] * &data.z_z_transpose[i]; // Observations are time-independent s2_contrib[i] = xi_exp_beta_z[i] * &data.z_z_transpose[i]; // Observations are time-independent
} }
// For each k, s1 is \sum_{i=1}^n I(t_k <= R*_j) E(ξ_j) exp(β^T Z_jk) Z_jk
// s1 is also the gradient of s0
let s1 = (0..data.num_times()).into_par_iter().map(|k| { let s1 = (0..data.num_times()).into_par_iter().map(|k| {
let mut s1_k = DVector::zeros(data.num_covs()); let mut s1_k = DVector::zeros(data.num_covs());
for i in 0..data.num_obs() { for i in 0..data.num_obs() {
@ -457,6 +460,7 @@ fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matri
s1_k s1_k
}).collect(); }).collect();
// For each k, s2 is Jacobian of s1
let s2 = (0..data.num_times()).into_par_iter().map(|k| { let s2 = (0..data.num_times()).into_par_iter().map(|k| {
let mut s2_k = DMatrix::zeros(data.num_covs(), data.num_covs()); let mut s2_k = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() { for i in 0..data.num_obs() {
@ -470,32 +474,37 @@ fn m_step_compute_s_values(data: &IntervalCensoredCoxData, xi_exp_beta_z: &Matri
return (s0, s1, s2); return (s0, s1, s2);
} }
fn m_step_compute_sigma(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, s2: &Vec<DMatrix<f64>>) -> DMatrix<f64> { fn m_step_compute_jacobian(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, s2: &Vec<DMatrix<f64>>) -> DMatrix<f64> {
// ComputeSigma // ComputeSigma
let mut sigma: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs()); let mut jacobian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for k in 0..data.num_times() { for k in 0..data.num_times() {
// factor_k derives from the quotient rule applied to the fraction in the LHS to be solved for 0
let factor_k = (s1[k].clone() / s0[k]) * (s1[k].transpose() / s0[k]) - (s2[k].clone() / s0[k]); let factor_k = (s1[k].clone() / s0[k]) * (s1[k].transpose() / s0[k]) - (s2[k].clone() / s0[k]);
let sum_posterior_weight = data.r_star_indicator.column(k).component_mul(&posterior_weight.column(k)).sum(); let sum_posterior_weight = data.r_star_indicator.column(k).component_mul(&posterior_weight.column(k)).sum();
sigma += sum_posterior_weight * factor_k.clone(); jacobian += sum_posterior_weight * factor_k.clone();
} }
return sigma; return jacobian;
} }
fn m_step_compute_new_beta(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, sigma: DMatrix<f64>, beta: &DVector<f64>) -> DVector<f64> { fn m_step_compute_new_beta(data: &IntervalCensoredCoxData, posterior_weight: &DMatrix<f64>, s0: &DVector<f64>, s1: &Vec<DVector<f64>>, jacobian: DMatrix<f64>, beta: &DVector<f64>) -> DVector<f64> {
// ComputeNewBeta // ComputeNewBeta
assert!(sigma.clone().full_piv_lu().is_invertible(), "Sigma is not invertible"); assert!(jacobian.clone().full_piv_lu().is_invertible(), "Jacobian is not invertible");
let mut sum: DVector<f64> = DVector::zeros(data.num_covs()); let mut lhs_value: DVector<f64> = DVector::zeros(data.num_covs());
for k in 0..data.num_times() { for k in 0..data.num_times() {
let quotient_k = s1[k].clone() / s0[k]; let quotient_k = s1[k].clone() / s0[k];
for i in 0..data.num_obs() { for i in 0..data.num_obs() {
if data.r_star_indicator[(i, k)] == 1.0 { if data.r_star_indicator[(i, k)] == 1.0 {
sum += posterior_weight[(i, k)] * (data.data_indep.column(i) - &quotient_k); lhs_value += posterior_weight[(i, k)] * (data.data_indep.column(i) - &quotient_k);
} }
} }
} }
let new_beta = beta.clone() - sigma.try_inverse().unwrap() * sum; // lhs_value = value of the LHS to be solved for 0 vector, \sum_{i=1}^n \sum_{j=1}^k I(t_k <= R*_j) etc...
// jacobian = Jacobian of LHS
// new_beta is therefore obtained by Newton's method
let new_beta = beta.clone() - jacobian.try_inverse().unwrap() * lhs_value;
return new_beta; return new_beta;
} }