Transpose data_time_indexes in memory to avoid unnecessary matrix transposition

This commit is contained in:
RunasSudo 2023-05-01 00:13:32 +10:00
parent 1c08116f10
commit 6c5ab0dd60
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
1 changed files with 41 additions and 36 deletions

View File

@ -16,8 +16,8 @@
// NOTE(review): judging by its name and value, presumably the 97.5th percentile of the standard normal distribution (its use is outside this view - confirm)
const Z_97_5: f64 = 1.959964; // Rounded to 6 decimal places. NOTE(review): this was described as "the limit of resolution for an f64", but an f64 carries roughly 15-17 significant decimal digits - confirm whether more digits are wanted
// Row indexes (left time, right time) used with the 2-row Matrix2xX layout, e.g. s[(ROW_LEFT, i)]
// NOTE(review): in this diff view these belong to the removed (pre-commit) side
const ROW_LEFT: usize = 0;
const ROW_RIGHT: usize = 1;
// Column indexes (left time, right time) used with the 2-column MatrixXx2 layout, e.g. s[(i, COL_LEFT)]
// NOTE(review): in this diff view these belong to the added (post-commit) side
const COL_LEFT: usize = 0;
const COL_RIGHT: usize = 1;
use core::mem::MaybeUninit;
use std::io;
@ -25,7 +25,7 @@ use std::io;
use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
use prettytable::{Table, format, row};
use rayon::prelude::*;
use serde::{Serialize, Deserialize};
@ -100,7 +100,7 @@ pub fn main(args: IntCoxArgs) {
}
}
pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
pub fn read_data(path: &str) -> (Vec<String>, MatrixXx2<f64>, DMatrix<f64>) {
// Read CSV into memory
let headers: StringRecord;
let records: Vec<StringRecord>;
@ -115,10 +115,12 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
}
// Read data into matrices
// Note: data_times has one ROW per observation, whereas data_indep has one COLUMN per observation
// See comment in IntervalCensoredCoxData
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
Const::<2>, // Left time, right time
Dyn(records.len())
let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
Dyn(records.len()),
Const::<2> // Left time, right time
);
// Called "Z" in the paper and "X" in the C++ code
@ -139,7 +141,7 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
};
if j < 2 {
data_times[(j, i)].write(value);
data_times[(i, j)].write(value);
} else {
data_indep[(j - 2, i)].write(value);
}
@ -156,8 +158,11 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
}
struct IntervalCensoredCoxData {
// BEWARE! data_time_indexes has one ROW per observation, whereas data_indep has one COLUMN per observation
// This improves the speed later by avoiding unnecessary matrix transposition
//data_times: DMatrix<f64>,
data_time_indexes: Matrix2xX<usize>,
data_time_indexes: MatrixXx2<usize>,
data_indep: DMatrix<f64>,
// Cached intermediate values
@ -178,7 +183,7 @@ impl IntervalCensoredCoxData {
}
}
pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
pub fn fit_interval_censored_cox(data_times: MatrixXx2<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
// ----------------------
// Prepare for regression
@ -200,7 +205,7 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
// Recode times as indexes
// TODO: HashMap?
let data_time_indexes = Matrix2xX::from_iterator(data_times.ncols(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));
let data_time_indexes = MatrixXx2::from_iterator(data_times.nrows(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));
// Initialise β, Λ
let mut beta: DVector<f64> = DVector::zeros(data_indep.nrows());
@ -347,38 +352,38 @@ fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DV
return matrix_exp!(data.data_indep.tr_mul(beta));
}
// compute_s: builds the matrix `s` holding one "left" and one "right" value per observation
// Each value is matrix_exp!(-exp_z_beta[i] * lambda[k]), where k is the left or right entry of data.data_time_indexes for observation i
// (matrix_exp! is defined outside this view; presumably an element-wise exponential - TODO confirm)
// NOTE(review): this is a diff view with the +/- markers stripped. The Matrix2xX / set_row lines are the removed version
// (2 rows, one column per observation); the MatrixXx2 / set_column lines are the added version (one row per observation, 2 columns).
// Only one of the two sets exists in the actual file.
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> MatrixXx2<f64> {
// Replace every stored time index with the value of lambda at that index
let cumulative_hazard = MatrixXx2::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i])); // Cannot use apply() as different data types
let mut s = Matrix2xX::zeros(data.num_obs());
// Removed version: exp_z_beta is a column vector, so it had to be transposed to multiply element-wise with a row of cumulative_hazard
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
let mut s = MatrixXx2::zeros(data.num_obs());
// Added version: exp_z_beta and each column of cumulative_hazard are both column vectors, so no transpose is needed
s.set_column(COL_LEFT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(0))));
s.set_column(COL_RIGHT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(1))));
return s;
}
// log_likelihood_obs: returns, as a DVector with one entry per observation, the natural log of (left value of s - right value of s)
// NOTE(review): this is a diff view with the +/- markers stripped. The Matrix2xX / row() lines are the removed version and the
// MatrixXx2 / column() lines are the added version; only one of the two sets exists in the actual file.
fn log_likelihood_obs(s: &Matrix2xX<f64>) -> DVector<f64> {
// Removed version: the difference of two rows is a row vector, hence the final transpose() to obtain a DVector
return (s.row(0) - s.row(1)).apply_into(|l| *l = l.ln()).transpose();
fn log_likelihood_obs(s: &MatrixXx2<f64>) -> DVector<f64> {
// Added version: the difference of two columns is already a column vector, so no transpose is required
return (s.column(COL_LEFT) - s.column(COL_RIGHT)).apply_into(|l| *l = l.ln());
}
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>, log_likelihood: f64) -> (DVector<f64>, Matrix2xX<f64>, f64) {
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &MatrixXx2<f64>, log_likelihood: f64) -> (DVector<f64>, MatrixXx2<f64>, f64) {
// Compute gradient w.r.t. lambda
let mut lambda_gradient: DVector<f64> = DVector::zeros(data.num_times());
for i in 0..data.num_obs() {
let constant_factor = exp_z_beta[i] / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
lambda_gradient[data.data_time_indexes[(ROW_LEFT, i)]] -= s[(ROW_LEFT, i)] * constant_factor;
lambda_gradient[data.data_time_indexes[(ROW_RIGHT, i)]] += s[(ROW_RIGHT, i)] * constant_factor;
let constant_factor = exp_z_beta[i] / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]);
lambda_gradient[data.data_time_indexes[(i, COL_LEFT)]] -= s[(i, COL_LEFT)] * constant_factor;
lambda_gradient[data.data_time_indexes[(i, COL_RIGHT)]] += s[(i, COL_RIGHT)] * constant_factor;
}
// Compute diagonal elements of Hessian w.r.t lambda
let mut lambda_hessdiag: DVector<f64> = DVector::zeros(data.num_times());
for i in 0..data.num_obs() {
let denominator = s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
// TODO: Vectorise?
let denominator = s[(i, COL_LEFT)] - s[(i, COL_RIGHT)];
let aij_left = -s[(i, COL_LEFT)] * exp_z_beta[i];
let aij_right = s[(i, COL_RIGHT)] * exp_z_beta[i];
let aij_left = -s[(ROW_LEFT, i)] * exp_z_beta[i];
let aij_right = s[(ROW_RIGHT, i)] * exp_z_beta[i];
lambda_hessdiag[data.data_time_indexes[(ROW_LEFT, i)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
lambda_hessdiag[data.data_time_indexes[(i, COL_LEFT)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
lambda_hessdiag[data.data_time_indexes[(i, COL_RIGHT)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
}
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
@ -421,26 +426,26 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
}
}
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &MatrixXx2<f64>) -> DVector<f64> {
// Compute gradient and Hessian w.r.t. beta
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
for i in 0..data.num_obs() {
// TODO: Can this be vectorised? Seems unlikely however
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
let bli = s[(i, COL_LEFT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]];
let bri = s[(i, COL_RIGHT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]];
// Gradient
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
let z_factor = (bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]);
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
// Hessian
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]] * (s[(i, COL_RIGHT)] - bri);
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]] * (s[(i, COL_LEFT)] - bli);
z_factor /= s[(i, COL_LEFT)] - s[(i, COL_RIGHT)];
z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
z_factor -= ((bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)])).powi(2);
beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
}