Tidy update_beta
This commit is contained in:
parent
a58a52f682
commit
a1bb1568ad
@ -413,24 +413,20 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
|
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
|
||||||
// Compute gradient w.r.t. beta
|
// Compute gradient and Hessian w.r.t. beta
|
||||||
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
|
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
|
||||||
|
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
|
||||||
|
|
||||||
for i in 0..data.num_obs() {
|
for i in 0..data.num_obs() {
|
||||||
// TODO: Vectorise
|
// TODO: Can this be vectorised? Seems unlikely however
|
||||||
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
||||||
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
||||||
|
|
||||||
|
// Gradient
|
||||||
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
||||||
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
|
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
|
||||||
}
|
|
||||||
|
|
||||||
// Compute Hessian w.r.t. beta
|
|
||||||
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
|
|
||||||
for i in 0..data.num_obs() {
|
|
||||||
// TODO: Vectorise
|
|
||||||
// TODO: bli, bri same as above
|
|
||||||
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
|
||||||
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
|
||||||
|
|
||||||
|
// Hessian
|
||||||
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
|
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
|
||||||
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
|
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
|
||||||
z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
|
z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
|
||||||
|
Loading…
Reference in New Issue
Block a user