Improve performance
Avoid unnecessary zeroing of input data matrices
This commit is contained in:
parent
5876f724ad
commit
14c1944f92
@ -19,12 +19,13 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
|
||||
const ROW_LEFT: usize = 0;
|
||||
const ROW_RIGHT: usize = 1;
|
||||
|
||||
use core::mem::MaybeUninit;
|
||||
use std::io;
|
||||
|
||||
use clap::{Args, ValueEnum};
|
||||
use csv::{Reader, StringRecord};
|
||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||
use nalgebra::{DMatrix, DVector, Matrix2xX};
|
||||
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
||||
use prettytable::{Table, format, row};
|
||||
use rayon::prelude::*;
|
||||
use serde::{Serialize, Deserialize};
|
||||
@ -115,15 +116,15 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
||||
|
||||
// Read data into matrices
|
||||
|
||||
let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros(
|
||||
//2, // Left time, right time
|
||||
records.len()
|
||||
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
||||
Const::<2>, // Left time, right time
|
||||
Dyn(records.len())
|
||||
);
|
||||
|
||||
// Called "Z" in the paper and "X" in the C++ code
|
||||
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
|
||||
headers.len() - 2,
|
||||
records.len()
|
||||
let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
|
||||
Dyn(headers.len() - 2),
|
||||
Dyn(records.len())
|
||||
);
|
||||
|
||||
// Parse header row
|
||||
@ -138,9 +139,9 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
||||
};
|
||||
|
||||
if j < 2 {
|
||||
data_times[(j, i)] = value;
|
||||
data_times[(j, i)].write(value);
|
||||
} else {
|
||||
data_indep[(j - 2, i)] = value;
|
||||
data_indep[(j - 2, i)].write(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -148,7 +149,10 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
||||
// TODO: Fail on left time > right time
|
||||
// TODO: Fail on left time < 0
|
||||
|
||||
return (indep_names, data_times, data_indep);
|
||||
// SAFETY: assume_init is OK because we initialised all values above
|
||||
unsafe {
|
||||
return (indep_names, data_times.assume_init(), data_indep.assume_init());
|
||||
}
|
||||
}
|
||||
|
||||
struct IntervalCensoredCoxData {
|
||||
@ -328,22 +332,27 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
|
||||
};
|
||||
}
|
||||
|
||||
/// Use with Matrix.apply_into for exponentiation
|
||||
fn matrix_exp(v: &mut f64) {
|
||||
*v = v.exp();
|
||||
macro_rules! matrix_exp {
|
||||
($matrix: expr) => {
|
||||
{
|
||||
let mut matrix = $matrix;
|
||||
//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
|
||||
matrix.apply(|x| *x = x.exp());
|
||||
matrix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
||||
return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
|
||||
return matrix_exp!(data.data_indep.tr_mul(beta));
|
||||
}
|
||||
|
||||
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
||||
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
|
||||
|
||||
let mut s = Matrix2xX::zeros(data.num_obs());
|
||||
s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp));
|
||||
s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp));
|
||||
|
||||
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
|
||||
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
|
||||
return s;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user