From 6c5ab0dd60b5cc95df3857fb1472d8c8dfeb6317 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Mon, 1 May 2023 00:13:32 +1000 Subject: [PATCH] Transpose data_time_indexes in memory to avoid unnecessary matrix transposition --- src/intcox.rs | 77 +++++++++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/src/intcox.rs b/src/intcox.rs index 6378780..4d09688 100644 --- a/src/intcox.rs +++ b/src/intcox.rs @@ -16,8 +16,8 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64 -const ROW_LEFT: usize = 0; -const ROW_RIGHT: usize = 1; +const COL_LEFT: usize = 0; +const COL_RIGHT: usize = 1; use core::mem::MaybeUninit; use std::io; @@ -25,7 +25,7 @@ use std::io; use clap::{Args, ValueEnum}; use csv::{Reader, StringRecord}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; -use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX}; +use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2}; use prettytable::{Table, format, row}; use rayon::prelude::*; use serde::{Serialize, Deserialize}; @@ -100,7 +100,7 @@ pub fn main(args: IntCoxArgs) { } } -pub fn read_data(path: &str) -> (Vec, Matrix2xX, DMatrix) { +pub fn read_data(path: &str) -> (Vec, MatrixXx2, DMatrix) { // Read CSV into memory let headers: StringRecord; let records: Vec; @@ -115,10 +115,12 @@ pub fn read_data(path: &str) -> (Vec, Matrix2xX, DMatrix) { } // Read data into matrices + // Note: data_times has one ROW per observation, whereas data_indep has one COLUMN per observation + // See comment in IntervalCensoredCoxData - let mut data_times: Matrix2xX> = Matrix2xX::uninit( - Const::<2>, // Left time, right time - Dyn(records.len()) + let mut data_times: MatrixXx2> = MatrixXx2::uninit( + Dyn(records.len()), + Const::<2> // Left time, right time ); // Called "Z" in the paper and "X" in the C++ code @@ -139,7 +141,7 @@ pub fn read_data(path: &str) -> (Vec, Matrix2xX, DMatrix) { }; if j < 2 { - data_times[(j, i)].write(value); + data_times[(i, j)].write(value); } else { data_indep[(j - 2, i)].write(value); } @@ -156,8 +158,11 @@ pub fn read_data(path: &str) -> (Vec, Matrix2xX, DMatrix) { } struct IntervalCensoredCoxData { + // BEWARE! data_time_indexes has one ROW per observation, whereas data_indep has one COLUMN per observation + // This improves the speed later by avoiding unnecessary matrix transposition + //data_times: DMatrix, - data_time_indexes: Matrix2xX, + data_time_indexes: MatrixXx2, data_indep: DMatrix, // Cached intermediate values @@ -178,7 +183,7 @@ impl IntervalCensoredCoxData { } } -pub fn fit_interval_censored_cox(data_times: Matrix2xX, mut data_indep: DMatrix, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult { +pub fn fit_interval_censored_cox(data_times: MatrixXx2, mut data_indep: DMatrix, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult { // ---------------------- // Prepare for regression @@ -200,7 +205,7 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX, mut data_indep: DMa // Recode times as indexes // TODO: HashMap? - let data_time_indexes = Matrix2xX::from_iterator(data_times.ncols(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap())); + let data_time_indexes = MatrixXx2::from_iterator(data_times.nrows(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap())); // Initialise β, Λ let mut beta: DVector = DVector::zeros(data_indep.nrows()); @@ -347,38 +352,38 @@ fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector) -> DV return matrix_exp!(data.data_indep.tr_mul(beta)); } -fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector, exp_z_beta: &DVector) -> Matrix2xX { - let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i])); +fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector, exp_z_beta: &DVector) -> MatrixXx2 { + let cumulative_hazard = MatrixXx2::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i])); // Cannot use apply() as different data types - let mut s = Matrix2xX::zeros(data.num_obs()); - s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)))); - s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)))); + let mut s = MatrixXx2::zeros(data.num_obs()); + s.set_column(COL_LEFT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(0)))); + s.set_column(COL_RIGHT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(1)))); return s; } -fn log_likelihood_obs(s: &Matrix2xX) -> DVector { - return (s.row(0) - s.row(1)).apply_into(|l| *l = l.ln()).transpose(); +fn log_likelihood_obs(s: &MatrixXx2) -> DVector { + return (s.column(COL_LEFT) - s.column(COL_RIGHT)).apply_into(|l| *l = l.ln()); } -fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector, exp_z_beta: &DVector, s: &Matrix2xX, log_likelihood: f64) -> (DVector, Matrix2xX, f64) { +fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector, exp_z_beta: &DVector, s: &MatrixXx2, log_likelihood: f64) -> (DVector, MatrixXx2, f64) { // Compute gradient w.r.t. lambda let mut lambda_gradient: DVector = DVector::zeros(data.num_times()); for i in 0..data.num_obs() { - let constant_factor = exp_z_beta[i] / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]); - lambda_gradient[data.data_time_indexes[(ROW_LEFT, i)]] -= s[(ROW_LEFT, i)] * constant_factor; - lambda_gradient[data.data_time_indexes[(ROW_RIGHT, i)]] += s[(ROW_RIGHT, i)] * constant_factor; + let constant_factor = exp_z_beta[i] / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]); + lambda_gradient[data.data_time_indexes[(i, COL_LEFT)]] -= s[(i, COL_LEFT)] * constant_factor; + lambda_gradient[data.data_time_indexes[(i, COL_RIGHT)]] += s[(i, COL_RIGHT)] * constant_factor; } // Compute diagonal elements of Hessian w.r.t lambda let mut lambda_hessdiag: DVector = DVector::zeros(data.num_times()); for i in 0..data.num_obs() { - let denominator = s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]; + // TODO: Vectorise? + let denominator = s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]; + let aij_left = -s[(i, COL_LEFT)] * exp_z_beta[i]; + let aij_right = s[(i, COL_RIGHT)] * exp_z_beta[i]; - let aij_left = -s[(ROW_LEFT, i)] * exp_z_beta[i]; - let aij_right = s[(ROW_RIGHT, i)] * exp_z_beta[i]; - - lambda_hessdiag[data.data_time_indexes[(ROW_LEFT, i)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2); - lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2); + lambda_hessdiag[data.data_time_indexes[(i, COL_LEFT)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2); + lambda_hessdiag[data.data_time_indexes[(i, COL_RIGHT)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2); } // Here are the diagonal elements of G, being the negative diagonal elements of the Hessian @@ -421,26 +426,26 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector, exp_z_be } } -fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector, lambda: &DVector, exp_z_beta: &DVector, s: &Matrix2xX) -> DVector { +fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector, lambda: &DVector, exp_z_beta: &DVector, s: &MatrixXx2) -> DVector { // Compute gradient and Hessian w.r.t. beta let mut beta_gradient: DVector = DVector::zeros(data.num_covs()); let mut beta_hessian: DMatrix = DMatrix::zeros(data.num_covs(), data.num_covs()); for i in 0..data.num_obs() { // TODO: Can this be vectorised? Seems unlikely however - let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]]; - let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]]; + let bli = s[(i, COL_LEFT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]]; + let bri = s[(i, COL_RIGHT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]]; // Gradient - let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]); + let z_factor = (bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]); beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i); // Hessian - let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri); - z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli); - z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]; + let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]] * (s[(i, COL_RIGHT)] - bri); + z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]] * (s[(i, COL_LEFT)] - bli); + z_factor /= s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]; - z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2); + z_factor -= ((bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)])).powi(2); beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose(); }