Add test for intcox
This commit is contained in:
parent
46e3b189ce
commit
6ac2d9f055
@ -12,6 +12,9 @@ rayon = "1.7.0"
|
|||||||
serde = { version = "1.0.160", features = ["derive"] }
|
serde = { version = "1.0.160", features = ["derive"] }
|
||||||
serde_json = "1.0.96"
|
serde_json = "1.0.96"
|
||||||
|
|
||||||
|
[profile.test]
|
||||||
|
opt-level = 3
|
||||||
|
|
||||||
[profile.perf]
|
[profile.perf]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
debug = true
|
debug = true
|
||||||
|
@ -162,7 +162,7 @@ impl IntervalCensoredCoxData {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
|
pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatrix<f64>, max_iterations: u32, tolerance: f64, reduced: bool, progress_bar: ProgressBar) -> IntervalCensoredCoxResult {
|
||||||
// ----------------------
|
// ----------------------
|
||||||
// Prepare for regression
|
// Prepare for regression
|
||||||
|
|
||||||
@ -572,11 +572,11 @@ fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct IntervalCensoredCoxResult {
|
pub struct IntervalCensoredCoxResult {
|
||||||
params: Vec<f64>,
|
pub params: Vec<f64>,
|
||||||
params_se: Vec<f64>,
|
pub params_se: Vec<f64>,
|
||||||
ll_model: f64,
|
pub ll_model: f64,
|
||||||
ll_null: f64,
|
pub ll_null: f64,
|
||||||
// TODO: cumulative hazard, etc.
|
// TODO: cumulative hazard, etc.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
1
src/lib.rs
Normal file
1
src/lib.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod intcox;
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
mod intcox;
|
use hpstat::intcox;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(about="High-performance statistics implementations")]
|
#[command(about="High-performance statistics implementations")]
|
||||||
|
87
tests/intcox.rs
Normal file
87
tests/intcox.rs
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// hpstat: High-performance statistics implementations
|
||||||
|
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
use indicatif::ProgressBar;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
|
use hpstat::intcox::fit_interval_censored_cox;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_intcox_zeng_mao_lin() {
|
||||||
|
// Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with IntCens 0.2 output
|
||||||
|
|
||||||
|
let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap();
|
||||||
|
let lines: Vec<String> = contents.trim_end().split("\n").map(|s| s.to_string()).collect();
|
||||||
|
|
||||||
|
// Read data into matrices
|
||||||
|
|
||||||
|
let mut data_times: DMatrix<f64> = DMatrix::zeros(
|
||||||
|
2, // Left time, right time
|
||||||
|
lines.len() - 1 // Minus 1 row for header row
|
||||||
|
);
|
||||||
|
|
||||||
|
// Called "Z" in the paper and "X" in the C++ code
|
||||||
|
let mut data_indep: DMatrix<f64> = DMatrix::zeros(
|
||||||
|
lines[0].split(",").count() - 2,
|
||||||
|
lines.len() - 1 // Minus 1 row for header row
|
||||||
|
);
|
||||||
|
|
||||||
|
// Read data
|
||||||
|
// FIXME: Parse CSV more robustly
|
||||||
|
for (i, row) in lines.iter().skip(1).enumerate() {
|
||||||
|
for (j, item) in row.split(",").enumerate() {
|
||||||
|
let value = match item {
|
||||||
|
"inf" => f64::INFINITY,
|
||||||
|
_ => item.parse().expect("Malformed float")
|
||||||
|
};
|
||||||
|
|
||||||
|
if j < 2 {
|
||||||
|
data_times[(j, i)] = value;
|
||||||
|
} else {
|
||||||
|
data_indep[(j - 2, i)] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fit regression
|
||||||
|
let progress_bar = ProgressBar::hidden();
|
||||||
|
//let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar);
|
||||||
|
let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar);
|
||||||
|
|
||||||
|
// ./unireg --in zeng_mao_lin.csv --out out.txt --r 0.0 --model "(Left_Time, Right_Time) = Needle + Needle2 + LogAge + GenderM + RaceO + RaceW + GenderM_RaceO + GenderM_RaceW" --sep , --inf_char inf --convergence_threshold 0.002
|
||||||
|
|
||||||
|
assert!((result.ll_model - -603.205).abs() < 1.0);
|
||||||
|
|
||||||
|
assert!((result.params[0] - -0.18636961816695094).abs() < 0.01);
|
||||||
|
assert!((result.params[1] - 0.080478699024478656).abs() < 0.01);
|
||||||
|
assert!((result.params[2] - -0.71260450817296639).abs() < 0.01);
|
||||||
|
assert!((result.params[3] - -0.22937443803422858).abs() < 0.01);
|
||||||
|
assert!((result.params[4] - -0.14101449484871434).abs() < 0.01);
|
||||||
|
assert!((result.params[5] - -0.43894526362102332).abs() < 0.01);
|
||||||
|
assert!((result.params[6] - 0.064533885082884768).abs() < 0.01);
|
||||||
|
assert!((result.params[7] - 0.20970425315378016).abs() < 0.01);
|
||||||
|
|
||||||
|
assert!((result.params_se[0] - 0.41496954829036448).abs() < 0.01);
|
||||||
|
assert!((result.params_se[1] - 0.15086156546712554).abs() < 0.01);
|
||||||
|
assert!((result.params_se[2] - 0.36522062865858951).abs() < 0.01);
|
||||||
|
assert!((result.params_se[3] - 0.32195496906604004).abs() < 0.01);
|
||||||
|
assert!((result.params_se[4] - 0.3912241733944129).abs() < 0.01);
|
||||||
|
assert!((result.params_se[5] - 0.41907763222198746).abs() < 0.01);
|
||||||
|
assert!((result.params_se[6] - 0.45849947730170948).abs() < 0.01);
|
||||||
|
assert!((result.params_se[7] - 0.48803508171247434).abs() < 0.01);
|
||||||
|
}
|
1125
tests/zeng_mao_lin.csv
Normal file
1125
tests/zeng_mao_lin.csv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user