2023-04-17 22:12:07 +10:00
|
|
|
// hpstat: High-performance statistics implementations
|
|
|
|
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
|
|
|
|
//
|
|
|
|
// This program is free software: you can redistribute it and/or modify
|
|
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
|
|
// (at your option) any later version.
|
|
|
|
//
|
|
|
|
// This program is distributed in the hope that it will be useful,
|
|
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
// GNU Affero General Public License for more details.
|
|
|
|
//
|
|
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
|
|
|
use std::fs;
|
|
|
|
|
|
|
|
use indicatif::ProgressBar;
|
|
|
|
use nalgebra::DMatrix;
|
|
|
|
|
|
|
|
use hpstat::intcox::fit_interval_censored_cox;
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_intcox_zeng_mao_lin() {
|
2023-04-18 16:18:19 +10:00
|
|
|
// Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with Stata 17 output
|
2023-04-17 22:12:07 +10:00
|
|
|
|
|
|
|
let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap();
|
|
|
|
let lines: Vec<String> = contents.trim_end().split("\n").map(|s| s.to_string()).collect();
|
|
|
|
|
|
|
|
// Read data into matrices
|
|
|
|
|
|
|
|
let mut data_times: DMatrix<f64> = DMatrix::zeros(
|
|
|
|
2, // Left time, right time
|
|
|
|
lines.len() - 1 // Minus 1 row for header row
|
|
|
|
);
|
|
|
|
|
|
|
|
// Called "Z" in the paper and "X" in the C++ code
|
|
|
|
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
|
|
|
|
lines[0].split(",").count() - 2,
|
|
|
|
lines.len() - 1 // Minus 1 row for header row
|
|
|
|
);
|
|
|
|
|
|
|
|
// Read data
|
|
|
|
// FIXME: Parse CSV more robustly
|
|
|
|
for (i, row) in lines.iter().skip(1).enumerate() {
|
|
|
|
for (j, item) in row.split(",").enumerate() {
|
|
|
|
let value = match item {
|
|
|
|
"inf" => f64::INFINITY,
|
|
|
|
_ => item.parse().expect("Malformed float")
|
|
|
|
};
|
|
|
|
|
|
|
|
if j < 2 {
|
|
|
|
data_times[(j, i)] = value;
|
|
|
|
} else {
|
|
|
|
data_indep[(j - 2, i)] = value;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Fit regression
|
|
|
|
let progress_bar = ProgressBar::hidden();
|
|
|
|
//let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar);
|
|
|
|
let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar);
|
|
|
|
|
2023-04-18 16:18:19 +10:00
|
|
|
// import delimited "zeng_mao_lin.csv", case(preserve) numericcols(2)
|
|
|
|
// stintcox Needle Needle2 LogAge GenderM RaceO RaceW GenderM_RaceO GenderM_RaceW, interval(Left_Time Right_Time) full nohr favorspeed lrmodel
|
|
|
|
// stcurve, cumhaz outfile("cumhaz.dta")
|
2023-04-17 22:12:07 +10:00
|
|
|
|
2023-04-18 16:18:19 +10:00
|
|
|
assert!(rel_diff(result.ll_model, -604.82642) < 0.01);
|
|
|
|
assert!(rel_diff(result.ll_null, -608.64263) < 0.01);
|
2023-04-17 22:12:07 +10:00
|
|
|
|
2023-04-18 16:18:19 +10:00
|
|
|
assert!(rel_diff(result.params[0], -0.1869297) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[1], 0.0808377) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[2], -0.7088894) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[3], -0.2296864) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[4], -0.1408832) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[5], -0.4397316) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[6], 0.0642637) < 0.01);
|
|
|
|
assert!(rel_diff(result.params[7], 0.2110733) < 0.01);
|
2023-04-17 22:12:07 +10:00
|
|
|
|
2023-04-18 16:18:19 +10:00
|
|
|
assert!(rel_diff(result.params_se[0], 0.4148436) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[1], 0.1507537) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[2], 0.3653805) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[3], 0.3214563) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[4], 0.3889668) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[5], 0.4165912) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[6], 0.4557368) < 0.01);
|
|
|
|
assert!(rel_diff(result.params_se[7], 0.4853911) < 0.01);
|
|
|
|
|
|
|
|
// Check a few points on the cumulative hazard curve
|
|
|
|
assert_eq!(result.cumulative_hazard_times[0], 0.0);
|
|
|
|
assert_eq!(result.cumulative_hazard[0], 0.0);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[10], 3.43757) < 0.000001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[10], 0.01913) < 0.1);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[30], 3.710771) < 0.000001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[30], 0.0282363) < 0.1);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[80], 4.277966) < 0.000001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[80], 0.038723) < 0.1);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[180], 8.566904) < 0.000001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[180], 0.0564792) < 0.1);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[380], 19.61333) < 0.00001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[380], 0.1084475) < 0.1);
|
|
|
|
assert!(abs_diff(result.cumulative_hazard_times[880], 28.87403) < 0.00001);
|
|
|
|
assert!(rel_diff(result.cumulative_hazard[880], 0.1348967) < 0.1);
|
|
|
|
assert!(abs_diff(*result.cumulative_hazard_times.last().unwrap(), 42.78283) < 0.00001);
|
|
|
|
assert!(rel_diff(*result.cumulative_hazard.last().unwrap(), 0.1638222) < 0.1);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Absolute difference `|a - b|` between a computed value `a` and a reference value `b`.
fn abs_diff(a: f64, b: f64) -> f64 {
    let delta = a - b;
    delta.abs()
}
|
|
|
|
|
|
|
|
/// Relative difference `|(a - b) / b|` of a computed value `a` from a reference value `b`.
///
/// When `b` is zero the result is infinity, or NaN if `a` is also zero.
fn rel_diff(a: f64, b: f64) -> f64 {
    let delta = a - b;
    (delta / b).abs()
}
|