Improve performance
Mostly, use BLAS functions to reduce unnecessary allocations for intermediate steps
This commit is contained in:
parent
cdc59da178
commit
a58a52f682
@ -334,7 +334,7 @@ fn matrix_exp(v: &mut f64) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
||||||
return (data.data_indep.transpose() * beta).apply_into(matrix_exp);
|
return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
||||||
@ -372,15 +372,16 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
|
|||||||
lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
|
lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
|
||||||
|
let mut lambda_neghessdiag_nonsingular = -lambda_hessdiag;
|
||||||
|
lambda_neghessdiag_nonsingular.apply(|v| *v = *v + 1e-9); // Add a small epsilon to ensure non-singular
|
||||||
|
|
||||||
// To invert the diagonal matrix G, we simply have diag(1/diag(G))
|
// To invert the diagonal matrix G, we simply have diag(1/diag(G))
|
||||||
let mut lambda_invneghessdiag = lambda_hessdiag.clone();
|
let mut lambda_invneghessdiag = lambda_neghessdiag_nonsingular.clone();
|
||||||
lambda_invneghessdiag.apply(|v| *v = 1.0 / (-*v + 0.0001));
|
lambda_invneghessdiag.apply(|v| *v = 1.0 / *v);
|
||||||
|
|
||||||
let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient);
|
let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient);
|
||||||
|
|
||||||
let mut lambda_weights = lambda_hessdiag.clone();
|
|
||||||
lambda_weights.apply(|v| *v = -*v + 0.0001);
|
|
||||||
|
|
||||||
// Take as large a step as possible while the log-likelihood increases
|
// Take as large a step as possible while the log-likelihood increases
|
||||||
let mut step_size_exponent: i32 = 0;
|
let mut step_size_exponent: i32 = 0;
|
||||||
loop {
|
loop {
|
||||||
@ -388,7 +389,7 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
|
|||||||
let lambda_target = lambda + step_size * &lambda_nr_factors;
|
let lambda_target = lambda + step_size * &lambda_nr_factors;
|
||||||
|
|
||||||
// Do projection step
|
// Do projection step
|
||||||
let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_weights.clone());
|
let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_neghessdiag_nonsingular.clone());
|
||||||
lambda_new.apply(|l| *l = l.max(0.0));
|
lambda_new.apply(|l| *l = l.max(0.0));
|
||||||
|
|
||||||
// Constrain Λ(0) = 0
|
// Constrain Λ(0) = 0
|
||||||
@ -419,7 +420,7 @@ fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVe
|
|||||||
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
||||||
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
||||||
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
||||||
beta_gradient += z_factor * data.data_indep.column(i);
|
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute Hessian w.r.t. beta
|
// Compute Hessian w.r.t. beta
|
||||||
@ -436,7 +437,7 @@ fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVe
|
|||||||
|
|
||||||
z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
|
z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
|
||||||
|
|
||||||
beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
|
beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut beta_neghess = -beta_hessian;
|
let mut beta_neghess = -beta_hessian;
|
||||||
|
Loading…
Reference in New Issue
Block a user