Transpose data_time_indexes in memory to avoid unnecessary matrix transposition
This commit is contained in:
parent
1c08116f10
commit
6c5ab0dd60
@ -16,8 +16,8 @@
|
|||||||
|
|
||||||
const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
|
const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
|
||||||
|
|
||||||
const ROW_LEFT: usize = 0;
|
const COL_LEFT: usize = 0;
|
||||||
const ROW_RIGHT: usize = 1;
|
const COL_RIGHT: usize = 1;
|
||||||
|
|
||||||
use core::mem::MaybeUninit;
|
use core::mem::MaybeUninit;
|
||||||
use std::io;
|
use std::io;
|
||||||
@ -25,7 +25,7 @@ use std::io;
|
|||||||
use clap::{Args, ValueEnum};
|
use clap::{Args, ValueEnum};
|
||||||
use csv::{Reader, StringRecord};
|
use csv::{Reader, StringRecord};
|
||||||
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||||
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
use nalgebra::{Const, DMatrix, DVector, Dyn, MatrixXx2};
|
||||||
use prettytable::{Table, format, row};
|
use prettytable::{Table, format, row};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Serialize, Deserialize};
|
||||||
@ -100,7 +100,7 @@ pub fn main(args: IntCoxArgs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
pub fn read_data(path: &str) -> (Vec<String>, MatrixXx2<f64>, DMatrix<f64>) {
|
||||||
// Read CSV into memory
|
// Read CSV into memory
|
||||||
let headers: StringRecord;
|
let headers: StringRecord;
|
||||||
let records: Vec<StringRecord>;
|
let records: Vec<StringRecord>;
|
||||||
@ -115,10 +115,12 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read data into matrices
|
// Read data into matrices
|
||||||
|
// Note: data_times has one ROW per observation, whereas data_indep has one COLUMN per observation
|
||||||
|
// See comment in IntervalCensoredCoxData
|
||||||
|
|
||||||
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
let mut data_times: MatrixXx2<MaybeUninit<f64>> = MatrixXx2::uninit(
|
||||||
Const::<2>, // Left time, right time
|
Dyn(records.len()),
|
||||||
Dyn(records.len())
|
Const::<2> // Left time, right time
|
||||||
);
|
);
|
||||||
|
|
||||||
// Called "Z" in the paper and "X" in the C++ code
|
// Called "Z" in the paper and "X" in the C++ code
|
||||||
@ -139,7 +141,7 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if j < 2 {
|
if j < 2 {
|
||||||
data_times[(j, i)].write(value);
|
data_times[(i, j)].write(value);
|
||||||
} else {
|
} else {
|
||||||
data_indep[(j - 2, i)].write(value);
|
data_indep[(j - 2, i)].write(value);
|
||||||
}
|
}
|
||||||
@ -156,8 +158,11 @@ pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct IntervalCensoredCoxData {
|
struct IntervalCensoredCoxData {
|
||||||
|
// BEWARE! data_time_indexes has one ROW per observation, whereas data_indep has one COLUMN per observation
|
||||||
|
// This improves the speed later by avoiding unnecessary matrix transposition
|
||||||
|
|
||||||
//data_times: DMatrix<f64>,
|
//data_times: DMatrix<f64>,
|
||||||
data_time_indexes: Matrix2xX<usize>,
|
data_time_indexes: MatrixXx2<usize>,
|
||||||
data_indep: DMatrix<f64>,
|
data_indep: DMatrix<f64>,
|
||||||
|
|
||||||
// Cached intermediate values
|
// Cached intermediate values
|
||||||
@ -178,7 +183,7 @@ impl IntervalCensoredCoxData {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
|
pub fn fit_interval_censored_cox(data_times: MatrixXx2<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
|
||||||
// ----------------------
|
// ----------------------
|
||||||
// Prepare for regression
|
// Prepare for regression
|
||||||
|
|
||||||
@ -200,7 +205,7 @@ pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMa
|
|||||||
|
|
||||||
// Recode times as indexes
|
// Recode times as indexes
|
||||||
// TODO: HashMap?
|
// TODO: HashMap?
|
||||||
let data_time_indexes = Matrix2xX::from_iterator(data_times.ncols(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));
|
let data_time_indexes = MatrixXx2::from_iterator(data_times.nrows(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));
|
||||||
|
|
||||||
// Initialise β, Λ
|
// Initialise β, Λ
|
||||||
let mut beta: DVector<f64> = DVector::zeros(data_indep.nrows());
|
let mut beta: DVector<f64> = DVector::zeros(data_indep.nrows());
|
||||||
@ -347,38 +352,38 @@ fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DV
|
|||||||
return matrix_exp!(data.data_indep.tr_mul(beta));
|
return matrix_exp!(data.data_indep.tr_mul(beta));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> MatrixXx2<f64> {
|
||||||
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
|
let cumulative_hazard = MatrixXx2::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i])); // Cannot use apply() as different data types
|
||||||
|
|
||||||
let mut s = Matrix2xX::zeros(data.num_obs());
|
let mut s = MatrixXx2::zeros(data.num_obs());
|
||||||
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
|
s.set_column(COL_LEFT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(0))));
|
||||||
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
|
s.set_column(COL_RIGHT, &matrix_exp!((-exp_z_beta).component_mul(&cumulative_hazard.column(1))));
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log_likelihood_obs(s: &Matrix2xX<f64>) -> DVector<f64> {
|
fn log_likelihood_obs(s: &MatrixXx2<f64>) -> DVector<f64> {
|
||||||
return (s.row(0) - s.row(1)).apply_into(|l| *l = l.ln()).transpose();
|
return (s.column(COL_LEFT) - s.column(COL_RIGHT)).apply_into(|l| *l = l.ln());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>, log_likelihood: f64) -> (DVector<f64>, Matrix2xX<f64>, f64) {
|
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &MatrixXx2<f64>, log_likelihood: f64) -> (DVector<f64>, MatrixXx2<f64>, f64) {
|
||||||
// Compute gradient w.r.t. lambda
|
// Compute gradient w.r.t. lambda
|
||||||
let mut lambda_gradient: DVector<f64> = DVector::zeros(data.num_times());
|
let mut lambda_gradient: DVector<f64> = DVector::zeros(data.num_times());
|
||||||
for i in 0..data.num_obs() {
|
for i in 0..data.num_obs() {
|
||||||
let constant_factor = exp_z_beta[i] / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
let constant_factor = exp_z_beta[i] / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]);
|
||||||
lambda_gradient[data.data_time_indexes[(ROW_LEFT, i)]] -= s[(ROW_LEFT, i)] * constant_factor;
|
lambda_gradient[data.data_time_indexes[(i, COL_LEFT)]] -= s[(i, COL_LEFT)] * constant_factor;
|
||||||
lambda_gradient[data.data_time_indexes[(ROW_RIGHT, i)]] += s[(ROW_RIGHT, i)] * constant_factor;
|
lambda_gradient[data.data_time_indexes[(i, COL_RIGHT)]] += s[(i, COL_RIGHT)] * constant_factor;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute diagonal elements of Hessian w.r.t lambda
|
// Compute diagonal elements of Hessian w.r.t lambda
|
||||||
let mut lambda_hessdiag: DVector<f64> = DVector::zeros(data.num_times());
|
let mut lambda_hessdiag: DVector<f64> = DVector::zeros(data.num_times());
|
||||||
for i in 0..data.num_obs() {
|
for i in 0..data.num_obs() {
|
||||||
let denominator = s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
|
// TODO: Vectorise?
|
||||||
|
let denominator = s[(i, COL_LEFT)] - s[(i, COL_RIGHT)];
|
||||||
|
let aij_left = -s[(i, COL_LEFT)] * exp_z_beta[i];
|
||||||
|
let aij_right = s[(i, COL_RIGHT)] * exp_z_beta[i];
|
||||||
|
|
||||||
let aij_left = -s[(ROW_LEFT, i)] * exp_z_beta[i];
|
lambda_hessdiag[data.data_time_indexes[(i, COL_LEFT)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
|
||||||
let aij_right = s[(ROW_RIGHT, i)] * exp_z_beta[i];
|
lambda_hessdiag[data.data_time_indexes[(i, COL_RIGHT)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
|
||||||
|
|
||||||
lambda_hessdiag[data.data_time_indexes[(ROW_LEFT, i)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
|
|
||||||
lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
|
// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
|
||||||
@ -421,26 +426,26 @@ fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_be
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
|
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &MatrixXx2<f64>) -> DVector<f64> {
|
||||||
// Compute gradient and Hessian w.r.t. beta
|
// Compute gradient and Hessian w.r.t. beta
|
||||||
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
|
let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
|
||||||
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
|
let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
|
||||||
|
|
||||||
for i in 0..data.num_obs() {
|
for i in 0..data.num_obs() {
|
||||||
// TODO: Can this be vectorised? Seems unlikely however
|
// TODO: Can this be vectorised? Seems unlikely however
|
||||||
let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
|
let bli = s[(i, COL_LEFT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]];
|
||||||
let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
|
let bri = s[(i, COL_RIGHT)] * exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]];
|
||||||
|
|
||||||
// Gradient
|
// Gradient
|
||||||
let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
|
let z_factor = (bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)]);
|
||||||
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
|
beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
|
||||||
|
|
||||||
// Hessian
|
// Hessian
|
||||||
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
|
let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_RIGHT)]] * (s[(i, COL_RIGHT)] - bri);
|
||||||
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
|
z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(i, COL_LEFT)]] * (s[(i, COL_LEFT)] - bli);
|
||||||
z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
|
z_factor /= s[(i, COL_LEFT)] - s[(i, COL_RIGHT)];
|
||||||
|
|
||||||
z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
|
z_factor -= ((bri - bli) / (s[(i, COL_LEFT)] - s[(i, COL_RIGHT)])).powi(2);
|
||||||
|
|
||||||
beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
|
beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user