Improve performance
Avoid unnecessary zeroing of input data matrices
This commit is contained in:
parent
5876f724ad
commit
14c1944f92
@ -19,12 +19,13 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
|
|||||||
const ROW_LEFT: usize = 0;
|
const ROW_LEFT: usize = 0;
|
||||||
const ROW_RIGHT: usize = 1;
|
const ROW_RIGHT: usize = 1;
|
||||||
|
|
||||||
|
use core::mem::MaybeUninit;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
|
||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use csv::{Reader, StringRecord};
|
use csv::{Reader, StringRecord};
|
||||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use nalgebra::{DMatrix, DVector, Matrix2xX};
|
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
||||||
use prettytable::{Table, format, row};
|
use prettytable::{Table, format, row};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
@ -115,15 +116,15 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
|
|
||||||
// Read data into matrices
|
// Read data into matrices
|
||||||
|
|
||||||
let mut data_times: Matrix2xX<f64> = Matrix2xX::zeros(
|
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
||||||
//2, // Left time, right time
|
Const::<2>, // Left time, right time
|
||||||
records.len()
|
Dyn(records.len())
|
||||||
);
|
);
|
||||||
|
|
||||||
// Called "Z" in the paper and "X" in the C++ code
|
// Called "Z" in the paper and "X" in the C++ code
|
||||||
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
|
let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
|
||||||
headers.len() - 2,
|
Dyn(headers.len() - 2),
|
||||||
records.len()
|
Dyn(records.len())
|
||||||
);
|
);
|
||||||
|
|
||||||
// Parse header row
|
// Parse header row
|
||||||
@ -138,9 +139,9 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if j < 2 {
|
if j < 2 {
|
||||||
data_times[(j, i)] = value;
|
data_times[(j, i)].write(value);
|
||||||
} else {
|
} else {
|
||||||
data_indep[(j - 2, i)] = value;
|
data_indep[(j - 2, i)].write(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -148,7 +149,10 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
// TODO: Fail on left time > right time
|
// TODO: Fail on left time > right time
|
||||||
// TODO: Fail on left time < 0
|
// TODO: Fail on left time < 0
|
||||||
|
|
||||||
return (indep_names, data_times, data_indep);
|
// SAFETY: assume_init is OK because we initialised all values above
|
||||||
|
unsafe {
|
||||||
|
return (indep_names, data_times.assume_init(), data_indep.assume_init());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IntervalCensoredCoxData {
|
struct IntervalCensoredCoxData {
|
||||||
@ -328,22 +332,27 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Use with Matrix.apply_into for exponentiation
|
macro_rules! matrix_exp {
|
||||||
fn matrix_exp(v: &mut f64) {
|
($matrix: expr) => {
|
||||||
*v = v.exp();
|
{
|
||||||
|
let mut matrix = $matrix;
|
||||||
|
//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
|
||||||
|
matrix.apply(|x| *x = x.exp());
|
||||||
|
matrix
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
||||||
return data.data_indep.tr_mul(beta).apply_into(matrix_exp);
|
return matrix_exp!(data.data_indep.tr_mul(beta));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
||||||
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
|
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
|
||||||
|
|
||||||
let mut s = Matrix2xX::zeros(data.num_obs());
|
let mut s = Matrix2xX::zeros(data.num_obs());
|
||||||
s.set_row(0, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0)).apply_into(matrix_exp));
|
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
|
||||||
s.set_row(1, &(-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1)).apply_into(matrix_exp));
|
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
|
||||||
|
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user