Tidy update_beta

This commit is contained in:
RunasSudo 2023-04-29 18:29:33 +10:00
parent a58a52f682
commit a1bb1568ad
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A

View File

@ -413,24 +413,20 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
} }
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> { fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
// Compute gradient w.r.t. beta // Compute gradient and Hessian w.r.t. beta
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs()); let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() { for i in 0..data.num_obs() {
// TODO: Vectorise // TODO: Can this be vectorised? Seems unlikely however
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]]; let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]]; let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
// Gradient
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]); let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i); beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
}
// Compute Hessian w.r.t. beta
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() {
// TODO: Vectorise
// TODO: bli, bri same as above
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
// Hessian
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri); let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli); z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]; z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];