Improve performance

Mostly, use BLAS-style operations (`tr_mul`, `axpy`, `syger`) to avoid unnecessary allocations for intermediate results
This commit is contained in:
RunasSudo 2023-04-29 17:39:25 +10:00
parent cdc59da178
commit a58a52f682
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed file with 10 additions and 9 deletions

View File

@@ -334,7 +334,7 @@ fn matrix_exp(v: &mut f64) {
} }
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> { fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
return (data.data_indep.transpose() * beta).apply_into(matrix_exp); return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
} }
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> { fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
@@ -372,15 +372,16 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2); lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
} }
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
let mut lambda_neghessdiag_nonsingular = -lambda_hessdiag;
lambda_neghessdiag_nonsingular.apply(|v| *v = *v + 1e-9); // Add a small epsilon to ensure non-singular
// To invert the diagonal matrix G, we simply have diag(1/diag(G)) // To invert the diagonal matrix G, we simply have diag(1/diag(G))
let mut lambda_invneghessdiag = lambda_hessdiag.clone(); let mut lambda_invneghessdiag = lambda_neghessdiag_nonsingular.clone();
lambda_invneghessdiag.apply(|v| *v = 1.0 / (-*v + 0.0001)); lambda_invneghessdiag.apply(|v| *v = 1.0 / *v);
let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient); let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient);
let mut lambda_weights = lambda_hessdiag.clone();
lambda_weights.apply(|v| *v = -*v + 0.0001);
// Take as large a step as possible while the log-likelihood increases // Take as large a step as possible while the log-likelihood increases
let mut step_size_exponent: i32 = 0; let mut step_size_exponent: i32 = 0;
loop { loop {
@@ -388,7 +389,7 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
let lambda_target = lambda + step_size * &lambda_nr_factors; let lambda_target = lambda + step_size * &lambda_nr_factors;
// Do projection step // Do projection step
let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_weights.clone()); let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_neghessdiag_nonsingular.clone());
lambda_new.apply(|l| *l = l.max(0.0)); lambda_new.apply(|l| *l = l.max(0.0));
// Constrain Λ(0) = 0 // Constrain Λ(0) = 0
@@ -419,7 +420,7 @@ fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVe
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]]; let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]]; let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]); let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
beta_gradient += z_factor * data.data_indep.column(i); beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
} }
// Compute Hessian w.r.t. beta // Compute Hessian w.r.t. beta
@@ -436,7 +437,7 @@ fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVe
z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2); z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose(); beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
} }
let mut beta_neghess = -beta_hessian; let mut beta_neghess = -beta_hessian;