Improve performance

Avoid unnecessary zeroing of input data matrices
This commit is contained in:
RunasSudo 2023-04-30 15:45:55 +10:00
parent 5876f724ad
commit 14c1944f92
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 26 additions and 17 deletions

View File

@ -19,12 +19,13 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
const ROW_LEFT: usize = 0;
const ROW_RIGHT: usize = 1;
use core::mem::MaybeUninit;
use std::io;
use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{DMatrix, DVector, Matrix2xX};
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};
@ -115,15 +116,15 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
// Read data into matrices
let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros(
//2, // Left time, right time
records.len()
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
Const::<2>, // Left time, right time
Dyn(records.len())
);
// Called "Z" in the paper and "X" in the C++ code
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
headers.len() - 2,
records.len()
let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
Dyn(headers.len() - 2),
Dyn(records.len())
);
// Parse header row
@ -138,9 +139,9 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
};
if j < 2 {
data_times[(j, i)] = value;
data_times[(j, i)].write(value);
} else {
data_indep[(j - 2, i)] = value;
data_indep[(j - 2, i)].write(value);
}
}
}
@ -148,7 +149,10 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
// TODO: Fail on left time > right time
// TODO: Fail on left time < 0
return (indep_names, data_times, data_indep);
// SAFETY: assume_init is OK because we initialised all values above
unsafe {
return (indep_names, data_times.assume_init(), data_indep.assume_init());
}
}
struct IntervalCensoredCoxData {
@ -328,22 +332,27 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
};
}
/// Use with Matrix.apply_into for exponentiation
fn matrix_exp(v: &mut f64) {
*v = v.exp();
macro_rules! matrix_exp {
($matrix: expr) => {
{
let mut matrix = $matrix;
//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
matrix.apply(|x| *x = x.exp());
matrix
}
}
}
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
return matrix_exp!(data.data_indep.tr_mul(beta));
}
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
let mut s = Matrix2xX::zeros(data.num_obs());
s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp));
s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp));
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
return s;
}