// hpstat: High-performance statistics implementations // Copyright © 2023 Lee Yingtong Li (RunasSudo) // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Affero General Public License for more details. // // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . use std::fs; use indicatif::ProgressBar; use nalgebra::DMatrix; use hpstat::intcox::fit_interval_censored_cox; #[test] fn test_intcox_zeng_mao_lin() { // Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with IntCens 0.2 output let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap(); let lines: Vec = contents.trim_end().split("\n").map(|s| s.to_string()).collect(); // Read data into matrices let mut data_times: DMatrix = DMatrix::zeros( 2, // Left time, right time lines.len() - 1 // Minus 1 row for header row ); // Called "Z" in the paper and "X" in the C++ code let mut data_indep: DMatrix = DMatrix::zeros( lines[0].split(",").count() - 2, lines.len() - 1 // Minus 1 row for header row ); // Read data // FIXME: Parse CSV more robustly for (i, row) in lines.iter().skip(1).enumerate() { for (j, item) in row.split(",").enumerate() { let value = match item { "inf" => f64::INFINITY, _ => item.parse().expect("Malformed float") }; if j < 2 { data_times[(j, i)] = value; } else { data_indep[(j - 2, i)] = value; } } } // Fit regression let progress_bar = ProgressBar::hidden(); //let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar); let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar); // ./unireg --in zeng_mao_lin.csv --out out.txt --r 0.0 --model "(Left_Time, Right_Time) = Needle + Needle2 + LogAge + GenderM + RaceO + RaceW + GenderM_RaceO + GenderM_RaceW" --sep , --inf_char inf --convergence_threshold 0.002 assert!((result.ll_model - -603.205).abs() < 1.0); assert!((result.params[0] - -0.18636961816695094).abs() < 0.01); assert!((result.params[1] - 0.080478699024478656).abs() < 0.01); assert!((result.params[2] - -0.71260450817296639).abs() < 0.01); assert!((result.params[3] - -0.22937443803422858).abs() < 0.01); assert!((result.params[4] - -0.14101449484871434).abs() < 0.01); assert!((result.params[5] - -0.43894526362102332).abs() < 0.01); assert!((result.params[6] - 0.064533885082884768).abs() < 0.01); assert!((result.params[7] - 0.20970425315378016).abs() < 0.01); assert!((result.params_se[0] - 0.41496954829036448).abs() < 0.01); assert!((result.params_se[1] - 0.15086156546712554).abs() < 0.01); assert!((result.params_se[2] - 0.36522062865858951).abs() < 0.01); assert!((result.params_se[3] - 0.32195496906604004).abs() < 0.01); assert!((result.params_se[4] - 0.3912241733944129).abs() < 0.01); assert!((result.params_se[5] - 0.41907763222198746).abs() < 0.01); assert!((result.params_se[6] - 0.45849947730170948).abs() < 0.01); assert!((result.params_se[7] - 0.48803508171247434).abs() < 0.01); }