Improve performance

Avoid unnecessary zeroing of input data matrices
This commit is contained in:
RunasSudo 2023-04-30 15:45:55 +10:00
parent 5876f724ad
commit 14c1944f92
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 26 additions and 17 deletions

View File

@ -19,12 +19,13 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
const ROW_LEFT: usize = 0; const ROW_LEFT: usize = 0;
const ROW_RIGHT: usize = 1; const ROW_RIGHT: usize = 1;
use core::mem::MaybeUninit;
use std::io; use std::io;
use clap::{Args, ValueEnum}; use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord}; use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{DMatrix, DVector, Matrix2xX}; use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
use prettytable::{Table, format, row}; use prettytable::{Table, format, row};
use rayon::prelude::*; use rayon::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
@ -115,15 +116,15 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
// Read data into matrices // Read data into matrices
let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros( let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
//2, // Left time, right time Const::<2>, // Left time, right time
records.len() Dyn(records.len())
); );
// Called "Z" in the paper and "X" in the C++ code // Called "Z" in the paper and "X" in the C++ code
let mut data_indep: DMatrix<f64> = DMatrix::zeros( let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
headers.len() - 2, Dyn(headers.len() - 2),
records.len() Dyn(records.len())
); );
// Parse header row // Parse header row
@ -138,9 +139,9 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
}; };
if j < 2 { if j < 2 {
data_times[(j, i)] = value; data_times[(j, i)].write(value);
} else { } else {
data_indep[(j - 2, i)] = value; data_indep[(j - 2, i)].write(value);
} }
} }
} }
@ -148,7 +149,10 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
// TODO: Fail on left time > right time // TODO: Fail on left time > right time
// TODO: Fail on left time < 0 // TODO: Fail on left time < 0
return (indep_names, data_times, data_indep); // SAFETY: assume_init is OK because we initialised all values above
unsafe {
return (indep_names, data_times.assume_init(), data_indep.assume_init());
}
} }
struct IntervalCensoredCoxData { struct IntervalCensoredCoxData {
@ -328,22 +332,27 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
}; };
} }
/// Use with Matrix.apply_into for exponentiation macro_rules! matrix_exp {
fn matrix_exp(v: &mut f64) { ($matrix: expr) => {
*v = v.exp(); {
let mut matrix = $matrix;
//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
matrix.apply(|x| *x = x.exp());
matrix
}
}
} }
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> { fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
return data.data_indep.tr_mul(beta).apply_into(matrix_exp); return matrix_exp!(data.data_indep.tr_mul(beta));
} }
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> { fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i])); let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
let mut s = Matrix2xX::zeros(data.num_obs()); let mut s = Matrix2xX::zeros(data.num_obs());
s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp)); s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp)); s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
return s; return s;
} }