Also report and unit test cumulative hazard curve
This commit is contained in:
parent
6ac2d9f055
commit
9893657bb0
@ -339,6 +339,8 @@ pub fn fit_interval_censored_cox(data_times: DMatrix<f64>, mut data_indep: DMatr
|
||||
return IntervalCensoredCoxResult {
|
||||
params: beta_unstandardised.data.as_vec().clone(),
|
||||
params_se: beta_se_unstandardised.data.as_vec().clone(),
|
||||
cumulative_hazard: cumulative_hazard(&lambda).data.as_vec().clone(),
|
||||
cumulative_hazard_times: data.time_points,
|
||||
ll_model: ll_model,
|
||||
ll_null: ll_null,
|
||||
};
|
||||
@ -575,6 +577,8 @@ fn profile_log_likelihood_obs(data: &IntervalCensoredCoxData, beta: DVector<f64>
|
||||
pub struct IntervalCensoredCoxResult {
|
||||
pub params: Vec<f64>,
|
||||
pub params_se: Vec<f64>,
|
||||
pub cumulative_hazard: Vec<f64>,
|
||||
pub cumulative_hazard_times: Vec<f64>,
|
||||
pub ll_model: f64,
|
||||
pub ll_null: f64,
|
||||
// TODO: cumulative hazard, etc.
|
||||
|
@ -23,7 +23,7 @@ use hpstat::intcox::fit_interval_censored_cox;
|
||||
|
||||
#[test]
|
||||
fn test_intcox_zeng_mao_lin() {
|
||||
// Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with IntCens 0.2 output
|
||||
// Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with Stata 17 output
|
||||
|
||||
let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap();
|
||||
let lines: Vec<String> = contents.trim_end().split("\n").map(|s| s.to_string()).collect();
|
||||
@ -63,25 +63,54 @@ fn test_intcox_zeng_mao_lin() {
|
||||
//let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar);
|
||||
let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar);
|
||||
|
||||
// ./unireg --in zeng_mao_lin.csv --out out.txt --r 0.0 --model "(Left_Time, Right_Time) = Needle + Needle2 + LogAge + GenderM + RaceO + RaceW + GenderM_RaceO + GenderM_RaceW" --sep , --inf_char inf --convergence_threshold 0.002
|
||||
// import delimited "zeng_mao_lin.csv", case(preserve) numericcols(2)
|
||||
// stintcox Needle Needle2 LogAge GenderM RaceO RaceW GenderM_RaceO GenderM_RaceW, interval(Left_Time Right_Time) full nohr favorspeed lrmodel
|
||||
// stcurve, cumhaz outfile("cumhaz.dta")
|
||||
|
||||
assert!((result.ll_model - -603.205).abs() < 1.0);
|
||||
assert!(rel_diff(result.ll_model, -604.82642) < 0.01);
|
||||
assert!(rel_diff(result.ll_null, -608.64263) < 0.01);
|
||||
|
||||
assert!((result.params[0] - -0.18636961816695094).abs() < 0.01);
|
||||
assert!((result.params[1] - 0.080478699024478656).abs() < 0.01);
|
||||
assert!((result.params[2] - -0.71260450817296639).abs() < 0.01);
|
||||
assert!((result.params[3] - -0.22937443803422858).abs() < 0.01);
|
||||
assert!((result.params[4] - -0.14101449484871434).abs() < 0.01);
|
||||
assert!((result.params[5] - -0.43894526362102332).abs() < 0.01);
|
||||
assert!((result.params[6] - 0.064533885082884768).abs() < 0.01);
|
||||
assert!((result.params[7] - 0.20970425315378016).abs() < 0.01);
|
||||
assert!(rel_diff(result.params[0], -0.1869297) < 0.01);
|
||||
assert!(rel_diff(result.params[1], 0.0808377) < 0.01);
|
||||
assert!(rel_diff(result.params[2], -0.7088894) < 0.01);
|
||||
assert!(rel_diff(result.params[3], -0.2296864) < 0.01);
|
||||
assert!(rel_diff(result.params[4], -0.1408832) < 0.01);
|
||||
assert!(rel_diff(result.params[5], -0.4397316) < 0.01);
|
||||
assert!(rel_diff(result.params[6], 0.0642637) < 0.01);
|
||||
assert!(rel_diff(result.params[7], 0.2110733) < 0.01);
|
||||
|
||||
assert!((result.params_se[0] - 0.41496954829036448).abs() < 0.01);
|
||||
assert!((result.params_se[1] - 0.15086156546712554).abs() < 0.01);
|
||||
assert!((result.params_se[2] - 0.36522062865858951).abs() < 0.01);
|
||||
assert!((result.params_se[3] - 0.32195496906604004).abs() < 0.01);
|
||||
assert!((result.params_se[4] - 0.3912241733944129).abs() < 0.01);
|
||||
assert!((result.params_se[5] - 0.41907763222198746).abs() < 0.01);
|
||||
assert!((result.params_se[6] - 0.45849947730170948).abs() < 0.01);
|
||||
assert!((result.params_se[7] - 0.48803508171247434).abs() < 0.01);
|
||||
assert!(rel_diff(result.params_se[0], 0.4148436) < 0.01);
|
||||
assert!(rel_diff(result.params_se[1], 0.1507537) < 0.01);
|
||||
assert!(rel_diff(result.params_se[2], 0.3653805) < 0.01);
|
||||
assert!(rel_diff(result.params_se[3], 0.3214563) < 0.01);
|
||||
assert!(rel_diff(result.params_se[4], 0.3889668) < 0.01);
|
||||
assert!(rel_diff(result.params_se[5], 0.4165912) < 0.01);
|
||||
assert!(rel_diff(result.params_se[6], 0.4557368) < 0.01);
|
||||
assert!(rel_diff(result.params_se[7], 0.4853911) < 0.01);
|
||||
|
||||
// Check a few points on the cumulative hazard curve
|
||||
assert_eq!(result.cumulative_hazard_times[0], 0.0);
|
||||
assert_eq!(result.cumulative_hazard[0], 0.0);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[10], 3.43757) < 0.000001);
|
||||
assert!(rel_diff(result.cumulative_hazard[10], 0.01913) < 0.1);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[30], 3.710771) < 0.000001);
|
||||
assert!(rel_diff(result.cumulative_hazard[30], 0.0282363) < 0.1);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[80], 4.277966) < 0.000001);
|
||||
assert!(rel_diff(result.cumulative_hazard[80], 0.038723) < 0.1);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[180], 8.566904) < 0.000001);
|
||||
assert!(rel_diff(result.cumulative_hazard[180], 0.0564792) < 0.1);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[380], 19.61333) < 0.00001);
|
||||
assert!(rel_diff(result.cumulative_hazard[380], 0.1084475) < 0.1);
|
||||
assert!(abs_diff(result.cumulative_hazard_times[880], 28.87403) < 0.00001);
|
||||
assert!(rel_diff(result.cumulative_hazard[880], 0.1348967) < 0.1);
|
||||
assert!(abs_diff(*result.cumulative_hazard_times.last().unwrap(), 42.78283) < 0.00001);
|
||||
assert!(rel_diff(*result.cumulative_hazard.last().unwrap(), 0.1638222) < 0.1);
|
||||
}
|
||||
|
||||
fn abs_diff(a: f64, b: f64) -> f64 {
|
||||
return (a - b).abs();
|
||||
}
|
||||
|
||||
fn rel_diff(a: f64, b: f64) -> f64 {
|
||||
return ((a - b) / b).abs();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user