2023-04-17 17:50:43 +10:00
|
|
|
// hpstat: High-performance statistics implementations
|
|
|
|
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
|
|
|
|
//
|
|
|
|
// This program is free software: you can redistribute it and/or modify
|
|
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
|
|
// (at your option) any later version.
|
|
|
|
//
|
|
|
|
// This program is distributed in the hope that it will be useful,
|
|
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
// GNU Affero General Public License for more details.
|
|
|
|
//
|
|
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
|
|
|
// 97.5th percentile of the standard normal distribution, used for 95% confidence intervals
const Z_97_5: f64 = 1.959964;

// Row indexes into the Matrix2xX structures holding (left, right) censoring interval data
const ROW_LEFT: usize = 0;
const ROW_RIGHT: usize = 1;
|
|
|
|
|
2023-04-30 15:45:55 +10:00
|
|
|
use core::mem::MaybeUninit;
|
2023-04-17 17:50:43 +10:00
|
|
|
use std::io;
|
|
|
|
|
|
|
|
use clap::{Args, ValueEnum};
|
2023-04-21 17:39:24 +10:00
|
|
|
use csv::{Reader, StringRecord};
|
2023-04-21 17:21:33 +10:00
|
|
|
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
2023-04-30 15:45:55 +10:00
|
|
|
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX};
|
2023-04-17 17:50:43 +10:00
|
|
|
use prettytable::{Table, format, row};
|
|
|
|
use rayon::prelude::*;
|
|
|
|
use serde::{Serialize, Deserialize};
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
use crate::pava::monotonic_regression_pava;
|
2023-04-21 17:21:33 +10:00
|
|
|
use crate::term::UnconditionalTermLike;
|
|
|
|
|
2023-04-17 17:50:43 +10:00
|
|
|
// Command-line arguments for the interval-censored Cox regression subcommand.
// NOTE: the `///` doc comments on the fields below are surfaced by clap as CLI help text,
// so they must not be reworded without intending to change the CLI output.
#[derive(Args)]
pub struct IntCoxArgs {
	/// Path to CSV input file containing the observations
	#[arg()]
	input: String,
	
	/// Output format
	#[arg(long, value_enum, default_value="text")]
	output: OutputFormat,
	
	/// Maximum number of iterations to attempt
	#[arg(long, default_value="1000")]
	max_iterations: u32,
	
	/// Terminate algorithm when the absolute change in log-likelihood is less than this tolerance
	#[arg(long, default_value="0.01")]
	ll_tolerance: f64,
}
|
|
|
|
|
|
|
|
// Supported output formats for displaying regression results
#[derive(ValueEnum, Clone)]
enum OutputFormat {
	// Human-readable summary table printed to stdout
	Text,
	// Machine-readable JSON (serialised IntervalCensoredCoxResult)
	Json
}
|
|
|
|
|
|
|
|
/// Entry point for the interval-censored Cox subcommand:
/// reads the input CSV, fits the regression, and prints the results
/// in the requested output format.
pub fn main(args: IntCoxArgs) {
	// Read data
	let (indep_names, data_times, data_indep) = read_data(&args.input);
	
	// Fit regression
	// The progress bar draws to stderr via UnconditionalTermLike even when stderr is not a TTY
	let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr())));
	let result = fit_interval_censored_cox(data_times, data_indep, progress_bar, args.max_iterations, args.ll_tolerance);
	
	// Display output
	match args.output {
		OutputFormat::Text => {
			// Blank lines separate the results from the progress bar output above
			println!();
			println!();
			println!("LL-Model = {:.5}", result.ll_model);
			println!("LL-Null = {:.5}", result.ll_null);
			
			let mut summary = Table::new();
			// Horizontal rules above/below the header and at the bottom only
			let format = format::FormatBuilder::new()
				.separators(&[format::LinePosition::Top, format::LinePosition::Title, format::LinePosition::Bottom], format::LineSeparator::new('-', '+', '+', '+'))
				.padding(2, 2)
				.build();
			summary.set_format(format);
			
			// prettytable cell specs: c-> = centred, r-> = right-aligned, H2c-> = centred heading spanning 2 columns
			summary.set_titles(row!["Parameter", c->"β", c->"Std Err.", c->"exp(β)", H2c->"(95% CI)"]);
			for (i, indep_name) in indep_names.iter().enumerate() {
				summary.add_row(row![
					indep_name,
					r->format!("{:.5}", result.params[i]),
					r->format!("{:.5}", result.params_se[i]),
					r->format!("{:.5}", result.params[i].exp()),
					// Wald 95% CI, exponentiated onto the hazard ratio scale
					r->format!("({:.5},", (result.params[i] - Z_97_5 * result.params_se[i]).exp()),
					format!("{:.5})", (result.params[i] + Z_97_5 * result.params_se[i]).exp()),
				]);
			}
			summary.printstd();
		}
		OutputFormat::Json => {
			println!("{}", serde_json::to_string(&result).unwrap());
		}
	}
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
pub fn read_data(path: &str) -> (Vec<String>, Matrix2xX<f64>, DMatrix<f64>) {
|
2023-04-21 17:39:24 +10:00
|
|
|
// Read CSV into memory
|
|
|
|
let headers: StringRecord;
|
|
|
|
let records: Vec<StringRecord>;
|
|
|
|
if path == "-" {
|
|
|
|
let mut csv_reader = Reader::from_reader(io::stdin());
|
|
|
|
headers = csv_reader.headers().unwrap().clone();
|
|
|
|
records = csv_reader.records().map(|r| r.unwrap()).collect();
|
|
|
|
} else {
|
|
|
|
let mut csv_reader = Reader::from_path(path).unwrap();
|
|
|
|
headers = csv_reader.headers().unwrap().clone();
|
|
|
|
records = csv_reader.records().map(|r| r.unwrap()).collect();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Read data into matrices
|
|
|
|
|
2023-04-30 15:45:55 +10:00
|
|
|
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit(
|
|
|
|
Const::<2>, // Left time, right time
|
|
|
|
Dyn(records.len())
|
2023-04-21 17:39:24 +10:00
|
|
|
);
|
|
|
|
|
|
|
|
// Called "Z" in the paper and "X" in the C++ code
|
2023-04-30 15:45:55 +10:00
|
|
|
let mut data_indep: DMatrix<MaybeUninit<f64>> = DMatrix::uninit(
|
|
|
|
Dyn(headers.len() - 2),
|
|
|
|
Dyn(records.len())
|
2023-04-21 17:39:24 +10:00
|
|
|
);
|
|
|
|
|
|
|
|
// Parse header row
|
|
|
|
let indep_names: Vec<String> = headers.iter().skip(2).map(String::from).collect();
|
|
|
|
|
|
|
|
// Parse data
|
|
|
|
for (i, row) in records.iter().enumerate() {
|
|
|
|
for (j, item) in row.iter().enumerate() {
|
|
|
|
let value = match item {
|
|
|
|
"inf" => f64::INFINITY,
|
|
|
|
_ => item.parse().expect("Malformed float")
|
|
|
|
};
|
|
|
|
|
|
|
|
if j < 2 {
|
2023-04-30 15:45:55 +10:00
|
|
|
data_times[(j, i)].write(value);
|
2023-04-21 17:39:24 +10:00
|
|
|
} else {
|
2023-04-30 15:45:55 +10:00
|
|
|
data_indep[(j - 2, i)].write(value);
|
2023-04-21 17:39:24 +10:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
// TODO: Fail on left time > right time
|
|
|
|
// TODO: Fail on left time < 0
|
|
|
|
|
2023-04-30 15:45:55 +10:00
|
|
|
// SAFETY: assume_init is OK because we initialised all values above
|
|
|
|
unsafe {
|
|
|
|
return (indep_names, data_times.assume_init(), data_indep.assume_init());
|
|
|
|
}
|
2023-04-21 17:39:24 +10:00
|
|
|
}
|
|
|
|
|
2023-04-17 17:50:43 +10:00
|
|
|
/// Input data and cached intermediate values for interval-censored Cox regression.
struct IntervalCensoredCoxData {
	//data_times: DMatrix<f64>,
	// Left/right interval bounds for each observation, stored as indexes into time_points
	data_time_indexes: Matrix2xX<usize>,
	// Covariate matrix, one column per observation (called "Z" in the paper, "X" in the C++ code)
	data_indep: DMatrix<f64>,
	
	// Cached intermediate values
	// Sorted, deduplicated distinct observed times (t_0 = 0, t_1, ..., t_m)
	time_points: Vec<f64>,
}
|
|
|
|
|
|
|
|
impl IntervalCensoredCoxData {
|
|
|
|
fn num_obs(&self) -> usize {
|
|
|
|
return self.data_indep.ncols();
|
|
|
|
}
|
|
|
|
|
|
|
|
fn num_covs(&self) -> usize {
|
|
|
|
return self.data_indep.nrows();
|
|
|
|
}
|
|
|
|
|
|
|
|
fn num_times(&self) -> usize {
|
|
|
|
return self.time_points.len();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
/// Fits a Cox proportional hazards model to interval-censored observations using an
/// ICM (iterative convex minorant) / Newton-Raphson algorithm.
///
/// * `data_times` - 2×n matrix of (left, right) censoring interval bounds per observation
/// * `data_indep` - p×n matrix of covariates, one column per observation
/// * `progress_bar` - progress bar to report fitting progress on
/// * `max_iterations` - panic if convergence is not reached within this many iterations
/// * `ll_tolerance` - converged once the log-likelihood improves by less than this
///
/// Panics if convergence fails or the profile-likelihood information matrix is singular.
pub fn fit_interval_censored_cox(data_times: Matrix2xX<f64>, mut data_indep: DMatrix<f64>, progress_bar: ProgressBar, max_iterations: u32, ll_tolerance: f64) -> IntervalCensoredCoxResult {
	// ----------------------
	// Prepare for regression
	
	// Standardise values
	// column_variance is the population variance; rescale by n/(n-1) for the sample variance
	let indep_means = data_indep.column_mean();
	let indep_stdev = data_indep.column_variance().apply_into(|x| { *x = (*x * data_indep.ncols() as f64 / (data_indep.ncols() - 1) as f64).sqrt(); });
	for j in 0..data_indep.nrows() {
		data_indep.row_mut(j).apply(|x| *x = (*x - indep_means[j]) / indep_stdev[j]);
	}
	
	// Get time points (t_0 = 0, t_1, ..., t_m)
	// TODO: Reimplement Turnbull intervals
	let mut time_points: Vec<f64>;
	time_points = data_times.iter().copied().collect();
	time_points.push(0.0); // Ensure 0 is in the list
	//time_points.push(f64::INFINITY); // Ensure infinity is on the list
	time_points.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Cannot use .sort() as f64 does not implement Ord
	time_points.dedup();
	
	// Recode times as indexes into time_points (linear scan per element)
	// TODO: HashMap?
	let data_time_indexes = Matrix2xX::from_iterator(data_times.ncols(), data_times.iter().map(|t| time_points.iter().position(|x| x == t).unwrap()));
	
	// Initialise β = 0 and Λ as a linear ramp over the time points
	let mut beta: DVector<f64> = DVector::zeros(data_indep.nrows());
	let mut lambda: DVector<f64> = DVector::from_iterator(time_points.len(), (0..time_points.len()).map(|i| i as f64 / time_points.len() as f64));
	
	let data = IntervalCensoredCoxData {
		//data_times: data_times,
		data_time_indexes: data_time_indexes,
		data_indep: data_indep,
		time_points: time_points,
	};
	
	// -------------------
	// Apply ICM algorithm
	
	let mut exp_z_beta = compute_exp_z_beta(&data, &beta);
	let mut s = compute_s(&data, &lambda, &exp_z_beta);
	let mut ll_model = log_likelihood_obs(&s).sum();
	
	// Length u64::MAX lets progress be expressed as a fraction of u64::MAX below
	progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} {msg}").unwrap());
	progress_bar.set_length(u64::MAX);
	progress_bar.reset();
	progress_bar.println("Running ICM/NR algorithm to fit interval-censored Cox model");
	
	let mut iteration = 1;
	loop {
		// Update lambda (ICM step; log-likelihood ignored here as it is recomputed after the beta step)
		let lambda_new;
		(lambda_new, s, _) = update_lambda(&data, &lambda, &exp_z_beta, &s, ll_model);
		
		// Update beta (Newton-Raphson step)
		let beta_new = update_beta(&data, &beta, &lambda_new, &exp_z_beta, &s);
		
		// Compute new log-likelihood
		exp_z_beta = compute_exp_z_beta(&data, &beta_new);
		s = compute_s(&data, &lambda_new, &exp_z_beta);
		let ll_model_new = log_likelihood_obs(&s).sum();
		
		let mut converged = true;
		let ll_change = ll_model_new - ll_model;
		if ll_change > ll_tolerance {
			converged = false;
		}
		
		lambda = lambda_new;
		beta = beta_new;
		ll_model = ll_model_new;
		
		// Estimate progress bar according to either the order of magnitude of the ll_change relative to tolerance, or iteration/max_iterations
		let progress2 = (iteration as f64 / max_iterations as f64 * u64::MAX as f64) as u64;
		let progress3 = ((-ll_change.log10()).max(0.0) / -ll_tolerance.log10() * u64::MAX as f64) as u64;
		
		// Update progress bar (monotonically: never move backwards)
		progress_bar.set_position(progress_bar.position().max(progress3.max(progress2)));
		progress_bar.set_message(format!("Iteration {} (LL = {:.4}, ΔLL = {:.4})", iteration + 1, ll_model, ll_change));
		
		if converged {
			progress_bar.println(format!("ICM/NR converged in {} iterations", iteration));
			break;
		}
		
		iteration += 1;
		if iteration > max_iterations {
			panic!("Exceeded --max-iterations");
		}
	}
	
	// Unstandardise betas (undo the covariate standardisation applied above)
	let mut beta_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
	for (j, beta_value) in beta.iter().enumerate() {
		beta_unstandardised[j] = beta_value / indep_stdev[j];
	}
	
	// -------------------------
	// Compute covariance matrix
	
	// Compute profile log-likelihoods, perturbing each beta in turn by h
	let h = 5.0 / (data.num_obs() as f64).sqrt(); // "a constant of order n^(-1/2)"
	
	progress_bar.set_style(ProgressStyle::with_template("[{elapsed_precise}] {bar:40} Profile LL {pos}/{len}").unwrap());
	progress_bar.set_length(data.num_covs() as u64 + 2);
	progress_bar.reset();
	progress_bar.println("Profiling log-likelihood to compute covariance matrix");
	
	// ll_null = log-likelihood for null model
	// pll_toggle_zero = log-likelihoods for each observation at final beta
	// pll_toggle_one = log-likelihoods for each observation at toggled beta
	
	let ll_null = profile_log_likelihood_obs(&data, DVector::zeros(data.num_covs()), lambda.clone(), max_iterations, ll_tolerance).sum();
	
	let pll_toggle_zero: DVector<f64> = profile_log_likelihood_obs(&data, beta.clone(), lambda.clone(), max_iterations, ll_tolerance);
	progress_bar.inc(1);
	
	// Each covariate's profiling is independent, so parallelise with rayon
	let pll_toggle_one: Vec<DVector<f64>> = (0..data.num_covs()).into_par_iter().map(|j| {
		let mut pll_beta = beta.clone();
		pll_beta[j] += h;
		profile_log_likelihood_obs(&data, pll_beta, lambda.clone(), max_iterations, ll_tolerance)
	})
	.progress_with(progress_bar.clone())
	.collect();
	progress_bar.finish();
	
	// Accumulate the outer products of the per-observation profile score estimates
	let mut pll_matrix: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
	for i in 0..data.num_obs() {
		let toggle_none_i = pll_toggle_zero[i];
		let mut ps_i: DVector<f64> = DVector::zeros(data.num_covs());
		for p in 0..data.num_covs() {
			// Forward-difference estimate of the profile score for covariate p, observation i
			ps_i[p] = (pll_toggle_one[p][i] - toggle_none_i) / h;
		}
		pll_matrix += ps_i.clone() * ps_i.transpose();
	}
	
	let vcov = pll_matrix.try_inverse().expect("Matrix not invertible");
	
	// Unstandardise SEs
	let beta_se = vcov.diagonal().apply_into(|x| { *x = x.sqrt(); });
	let mut beta_se_unstandardised: DVector<f64> = DVector::zeros(data.num_covs());
	for (j, se) in beta_se.iter().enumerate() {
		beta_se_unstandardised[j] = se / indep_stdev[j];
	}
	
	return IntervalCensoredCoxResult {
		params: beta_unstandardised.data.as_vec().clone(),
		params_se: beta_se_unstandardised.data.as_vec().clone(),
		cumulative_hazard: lambda.data.as_vec().clone(),
		cumulative_hazard_times: data.time_points,
		ll_model: ll_model,
		ll_null: ll_null,
	};
}
|
|
|
|
|
2023-04-30 15:45:55 +10:00
|
|
|
/// Computes the elementwise exponential of a matrix expression, mutating in place.
/// Implemented as a macro so it works with any nalgebra matrix/expression type.
macro_rules! matrix_exp {
	($matrix: expr) => {
		{
			let mut matrix = $matrix;
			//matrix.data.as_mut_slice().par_iter_mut().for_each(|x| *x = x.exp()); // This is actually slower
			matrix.apply(|x| *x = x.exp());
			matrix
		}
	}
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
fn compute_exp_z_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>) -> DVector<f64> {
|
2023-04-30 15:45:55 +10:00
|
|
|
return matrix_exp!(data.data_indep.tr_mul(beta));
|
2023-04-28 01:02:09 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
fn compute_s(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>) -> Matrix2xX<f64> {
|
|
|
|
let cumulative_hazard = Matrix2xX::from_iterator(data.num_obs(), data.data_time_indexes.iter().map(|i| lambda[*i]));
|
|
|
|
|
|
|
|
let mut s = Matrix2xX::zeros(data.num_obs());
|
2023-04-30 15:45:55 +10:00
|
|
|
s.set_row(ROW_LEFT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(0))));
|
|
|
|
s.set_row(ROW_RIGHT, &matrix_exp!((-exp_z_beta).transpose().component_mul(&cumulative_hazard.row(1))));
|
2023-04-17 17:50:43 +10:00
|
|
|
return s;
|
|
|
|
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
fn log_likelihood_obs(s: &Matrix2xX<f64>) -> DVector<f64> {
|
|
|
|
return (s.row(0) - s.row(1)).apply_into(|l| *l = l.ln()).transpose();
|
2023-04-17 17:50:43 +10:00
|
|
|
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
/// Performs one ICM update of the baseline cumulative hazard Λ.
///
/// Takes a Newton-type step using the gradient and the diagonal of the Hessian of the
/// log-likelihood w.r.t. Λ, projects back onto monotone non-decreasing Λ via weighted
/// PAVA, and halves the step size until the log-likelihood strictly increases.
///
/// Returns the new (Λ, S, log-likelihood). Panics if no step size up to 2^-10 improves
/// the log-likelihood.
fn update_lambda(data: &IntervalCensoredCoxData, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>, log_likelihood: f64) -> (DVector<f64>, Matrix2xX<f64>, f64) {
	// Compute gradient w.r.t. lambda
	// Each observation contributes only at its left/right time indexes
	let mut lambda_gradient: DVector<f64> = DVector::zeros(data.num_times());
	for i in 0..data.num_obs() {
		let constant_factor = exp_z_beta[i] / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
		lambda_gradient[data.data_time_indexes[(ROW_LEFT, i)]] -= s[(ROW_LEFT, i)] * constant_factor;
		lambda_gradient[data.data_time_indexes[(ROW_RIGHT, i)]] += s[(ROW_RIGHT, i)] * constant_factor;
	}
	
	// Compute diagonal elements of Hessian w.r.t lambda
	let mut lambda_hessdiag: DVector<f64> = DVector::zeros(data.num_times());
	for i in 0..data.num_obs() {
		let denominator = s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
		
		let aij_left = -s[(ROW_LEFT, i)] * exp_z_beta[i];
		let aij_right = s[(ROW_RIGHT, i)] * exp_z_beta[i];
		
		lambda_hessdiag[data.data_time_indexes[(ROW_LEFT, i)]] += (-aij_left * exp_z_beta[i]) / denominator - (aij_left / denominator).powi(2);
		lambda_hessdiag[data.data_time_indexes[(ROW_RIGHT, i)]] += (-aij_right * exp_z_beta[i]) / denominator - (aij_right / denominator).powi(2);
	}
	
	// Here are the diagonal elements of G, being the negative diagonal elements of the Hessian
	let mut lambda_neghessdiag_nonsingular = -lambda_hessdiag;
	lambda_neghessdiag_nonsingular.apply(|v| *v = *v + 1e-9); // Add a small epsilon to ensure non-singular
	
	// To invert the diagonal matrix G, we simply have diag(1/diag(G))
	let mut lambda_invneghessdiag = lambda_neghessdiag_nonsingular.clone();
	lambda_invneghessdiag.apply(|v| *v = 1.0 / *v);
	
	// Newton step direction: G⁻¹ · gradient
	let lambda_nr_factors = lambda_invneghessdiag.component_mul(&lambda_gradient);
	
	// Take as large a step as possible while the log-likelihood increases
	let mut step_size_exponent: i32 = 0;
	loop {
		let step_size = 0.5_f64.powi(step_size_exponent);
		let lambda_target = lambda + step_size * &lambda_nr_factors;
		
		// Do projection step (weighted monotonic regression, weights = diag(G))
		let mut lambda_new = monotonic_regression_pava(lambda_target, lambda_neghessdiag_nonsingular.clone());
		lambda_new.apply(|l| *l = l.max(0.0)); // Cumulative hazard cannot be negative
		
		// Constrain Λ(0) = 0
		lambda_new[0] = 0.0;
		
		let s_new = compute_s(data, &lambda_new, exp_z_beta);
		let log_likelihood_new = log_likelihood_obs(&s_new).sum();
		
		if log_likelihood_new > log_likelihood {
			return (lambda_new, s_new, log_likelihood_new);
		}
		
		step_size_exponent += 1;
		
		if step_size_exponent > 10 {
			// This shouldn't happen unless there is a numeric problem with the gradient/Hessian
			panic!("ICM fails to increase log-likelihood");
			//return (lambda.clone(), s.clone(), log_likelihood);
		}
	}
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
/// Performs one Newton-Raphson update of the coefficients β, holding Λ fixed.
///
/// Computes the gradient and Hessian of the log-likelihood w.r.t. β and returns
/// β − H⁻¹·gradient. Panics if the Hessian is singular.
fn update_beta(data: &IntervalCensoredCoxData, beta: &DVector<f64>, lambda: &DVector<f64>, exp_z_beta: &DVector<f64>, s: &Matrix2xX<f64>) -> DVector<f64> {
	// Compute gradient and Hessian w.r.t. beta
	let mut beta_gradient: DVector<f64> = DVector::zeros(data.num_covs());
	let mut beta_hessian: DMatrix<f64> = DMatrix::zeros(data.num_covs(), data.num_covs());
	
	for i in 0..data.num_obs() {
		// TODO: Can this be vectorised? Seems unlikely however
		// bli/bri = S(t)·exp(Zᵀβ)·Λ(t) at the left/right interval bounds
		let bli = s[(ROW_LEFT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]];
		let bri = s[(ROW_RIGHT, i)] * exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]];
		
		// Gradient
		let z_factor = (bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)]);
		beta_gradient.axpy(z_factor, &data.data_indep.column(i), 1.0); // beta_gradient += z_factor * data.data_indep.column(i);
		
		// Hessian
		let mut z_factor = exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_RIGHT, i)]] * (s[(ROW_RIGHT, i)] - bri);
		z_factor -= exp_z_beta[i] * lambda[data.data_time_indexes[(ROW_LEFT, i)]] * (s[(ROW_LEFT, i)] - bli);
		z_factor /= s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)];
		
		z_factor -= ((bri - bli) / (s[(ROW_LEFT, i)] - s[(ROW_RIGHT, i)])).powi(2);
		
		// Symmetric rank-1 update of the Hessian
		// NOTE(review): nalgebra's syger documents a symmetric-matrix update — confirm
		// try_inverse_mut below sees the full matrix, not just one triangle
		beta_hessian.syger(z_factor, &data.data_indep.column(i), &data.data_indep.column(i), 1.0); // beta_hessian += z_factor * data.data_indep.column(i) * data.data_indep.column(i).transpose();
	}
	
	let mut beta_neghess = -beta_hessian;
	if !beta_neghess.try_inverse_mut() {
		panic!("Hessian is not invertible");
	}
	
	// Newton-Raphson step: β_new = β + (−H)⁻¹ · gradient
	let beta_new = beta + beta_neghess * beta_gradient;
	return beta_new;
}
|
|
|
|
|
2023-04-28 01:02:09 +10:00
|
|
|
/// Computes per-observation log-likelihoods with β held fixed, maximising over Λ only
/// (profile likelihood). Used when estimating the covariance matrix.
///
/// Runs the same ICM loop as the main fit but skips the β update.
/// Panics if convergence is not reached within max_iterations.
fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>, mut lambda: DVector<f64>, max_iterations: u32, ll_tolerance: f64) -> DVector<f64> {
	// -------------------
	// Apply ICM algorithm
	
	// β is fixed, so exp(Zᵀβ) need only be computed once
	let exp_z_beta = compute_exp_z_beta(&data, &beta);
	let mut s = compute_s(&data, &lambda, &exp_z_beta);
	let mut ll_model = log_likelihood_obs(&s).sum();
	
	let mut iteration = 1;
	loop {
		// Update lambda
		let (lambda_new, ll_model_new);
		(lambda_new, s, ll_model_new) = update_lambda(&data, &lambda, &exp_z_beta, &s, ll_model);
		
		// [Do not update beta]
		
		let mut converged = true;
		if ll_model_new - ll_model > ll_tolerance {
			converged = false;
		}
		
		lambda = lambda_new;
		ll_model = ll_model_new;
		
		if converged {
			return log_likelihood_obs(&s);
		}
		
		iteration += 1;
		if iteration > max_iterations {
			panic!("Exceeded --max-iterations");
		}
	}
}
|
|
|
|
|
|
|
|
/// Results of fitting an interval-censored Cox model; serialisable for JSON output.
#[derive(Serialize, Deserialize)]
pub struct IntervalCensoredCoxResult {
	// Fitted coefficients β, on the original (unstandardised) covariate scale
	pub params: Vec<f64>,
	// Standard errors of the fitted coefficients
	pub params_se: Vec<f64>,
	// Estimated baseline cumulative hazard Λ at each time point
	pub cumulative_hazard: Vec<f64>,
	// Time points corresponding to cumulative_hazard
	pub cumulative_hazard_times: Vec<f64>,
	// Maximised log-likelihood of the fitted model
	pub ll_model: f64,
	// Log-likelihood of the null model (all β = 0)
	pub ll_null: f64,
}
|