diff --git a/tests/sas.csv b/tests/sas.csv new file mode 100644 index 0000000..2af1b92 --- /dev/null +++ b/tests/sas.csv @@ -0,0 +1,47 @@ +LTime,RTime +45,inf +25,37 +37,inf +6,10 +46,inf +0,5 +0,7 +26,40 +18,inf +46,inf +46,inf +24,inf +46,inf +27,34 +36,inf +7,16 +36,44 +5,11 +17,inf +46,inf +19,35 +7,14 +36,48 +17,25 +37,44 +37,inf +24,inf +0,8 +40,inf +32,inf +4,11 +17,25 +33,inf +15,inf +46,inf +19,26 +11,15 +11,18 +37,inf +22,inf +38,inf +34,inf +46,inf +5,12 +36,inf +46,inf diff --git a/tests/turnbull.rs b/tests/turnbull.rs index 8a92cab..a414961 100644 --- a/tests/turnbull.rs +++ b/tests/turnbull.rs @@ -26,7 +26,7 @@ fn test_turnbull_minitab() { // Fit regression let progress_bar = ProgressBar::hidden(); - let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.01, turnbull::SEMethod::OIM, 0.0001); + let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.01, turnbull::SEMethod::OIM, 0.0001, 0.01); assert_eq!(result.failure_intervals[0], (20000.0, 30000.0)); assert_eq!(result.failure_intervals[1], (30000.0, 40000.0)); @@ -64,6 +64,46 @@ fn test_turnbull_minitab() { assert!(abs_diff(survival_prob_se[6], 0.0123546) < 0.0000001); } +#[test] +fn test_turnbull_sas() { + // Compare "RT" example with SAS output + + let data_times = turnbull::read_data("tests/sas.csv"); + + // Fit regression + let progress_bar = ProgressBar::hidden(); + let result = turnbull::fit_turnbull(data_times, progress_bar, 500, 0.0001, turnbull::SEMethod::None, 0.0001, 0.01); + + assert_eq!(result.failure_intervals[0], (4.0, 5.0)); + assert_eq!(result.failure_intervals[1], (6.0, 7.0)); + assert_eq!(result.failure_intervals[2], (7.0, 8.0)); + assert_eq!(result.failure_intervals[3], (11.0, 12.0)); + assert_eq!(result.failure_intervals[4], (15.0, 16.0)); // Not in SAS + assert_eq!(result.failure_intervals[5], (17.0, 18.0)); // Not in SAS + assert_eq!(result.failure_intervals[6], (24.0, 25.0)); + assert_eq!(result.failure_intervals[7], (25.0, 26.0)); // Not in SAS + assert_eq!(result.failure_intervals[8], (33.0, 34.0)); + assert_eq!(result.failure_intervals[9], (34.0, 35.0)); // Not in SAS + assert_eq!(result.failure_intervals[10], (36.0, 37.0)); // Not in SAS + assert_eq!(result.failure_intervals[11], (38.0, 40.0)); + assert_eq!(result.failure_intervals[12], (40.0, 44.0)); // Not in SAS + assert_eq!(result.failure_intervals[13], (46.0, 48.0)); + + assert!(abs_diff(result.survival_prob[0], 0.9537) < 0.0001); + assert!(abs_diff(result.survival_prob[1], 0.9203) < 0.0001); + assert!(abs_diff(result.survival_prob[2], 0.8316) < 0.0001); + assert!(abs_diff(result.survival_prob[3], 0.7609) < 0.0001); + assert!(abs_diff(result.survival_prob[4], 0.7609) < 0.0001); + assert!(abs_diff(result.survival_prob[5], 0.7609) < 0.0001); + assert!(abs_diff(result.survival_prob[6], 0.6682) < 0.0001); + assert!(abs_diff(result.survival_prob[7], 0.6682) < 0.0001); + assert!(abs_diff(result.survival_prob[8], 0.5864) < 0.0001); + assert!(abs_diff(result.survival_prob[9], 0.5864) < 0.0001); + assert!(abs_diff(result.survival_prob[10], 0.5864) < 0.0001); + assert!(abs_diff(result.survival_prob[11], 0.4656) < 0.0001); + assert!(abs_diff(result.survival_prob[12], 0.4656) < 0.0001); +} + fn abs_diff(a: f64, b: f64) -> f64 { return (a - b).abs(); }