From 14c1944f92f681aee6f92d63a72fe6cf6a5775fc Mon Sep 17 00:00:00 2001
From: RunasSudo
Date: Sun, 30 Apr 2023 15:45:55 +1000
Subject: [PATCH] Improve performance

Avoid unnecessary zeroing of input data matrices
---
 src/intcox.rs | 43 ++++++++++++++++++++++++++-----------------
 1 file changed, 26 insertions(+), 17 deletions(-)

diff --git a/src/intcox.rs b/src/intcox.rs
index e32a669..b56f1df 100644
--- a/src/intcox.rs
+++ b/src/intcox.rs
@@ -19,12 +19,13 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
 const ROW_LEFT: usize = 0;
 const ROW_RIGHT: usize = 1;
 
+use core::mem::MaybeUninit;
 use std::io;
 
 use clap::{Args, ValueEnum};
 use csv::{Reader, StringRecord};
 use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
-use nalgebra::{DMatrix, DVector, Matrix2xX};
+use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
 use prettytable::{Table, format, row};
 use rayon::prelude::*;
 use serde::{Serialize, Deserialize};
@@ -115,15 +116,15 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
 
 	// Read data into matrices
 
-	let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros(
-		//2, // Left time, right time
-		records.len()
+	let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
+		Const::<2>, // Left time, right time
+		Dyn(records.len())
 	);
 
 	// Called "Z" in the paper and "X" in the C++ code
-	let mut data_indep: DMatrix<f64> = DMatrix::zeros(
-		headers.len() - 2,
-		records.len()
+	let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
+		Dyn(headers.len() - 2),
+		Dyn(records.len())
 	);
 
 	// Parse header row
@@ -138,9 +139,9 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
 			};
 
 			if j < 2 {
-				data_times[(j, i)] = value;
+				data_times[(j, i)].write(value);
 			} else {
-				data_indep[(j - 2, i)] = value;
+				data_indep[(j - 2, i)].write(value);
 			}
 		}
 	}
@@ -148,7 +149,10 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
 	// TODO: Fail on left time > right time
 	// TODO: Fail on left time < 0
 
-	return (indep_names, data_times, data_indep);
+	// SAFETY: assume_init is OK because we initialised all values above
+	unsafe {
+		return (indep_names, data_times.assume_init(), data_indep.assume_init());
+	}
 }
 
 struct IntervalCensoredCoxData {
@@ -328,22 +332,27 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
 	};
 }
 
-/// Use with Matrix.apply_into for exponentiation
-fn matrix_exp(v: &mut f64) {
-	*v = v.exp();
+macro_rules! matrix_exp {
+	($matrix: expr) => {
+		{
+			let mut matrix = $matrix;
+			//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
+			matrix.apply(|x| *x = x.exp());
+			matrix
+		}
+	}
 }
 
 fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
-	return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
+	return matrix_exp!(data.data_indep.tr_mul(beta));
 }
 
 fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
 	let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
 
 	let mut s = Matrix2xX::zeros(data.num_obs());
-	s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp));
-	s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp));
-
+	s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
+	s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
 	return s;
 }
 