diff --git a/Cargo.lock b/Cargo.lock index 8689cf0..4c99a24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -313,6 +313,7 @@ version = "0.1.0" dependencies = [ "clap", "console", + "csv", "indicatif", "nalgebra", "prettytable-rs", diff --git a/Cargo.toml b/Cargo.toml index 1246e78..72d1c63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] clap = { version = "4.2.1", features = ["derive"] } console = "0.15.5" +csv = "1.2.1" indicatif = {version = "0.17.3", features = ["rayon"]} nalgebra = "0.32.2" prettytable-rs = "0.10.0" diff --git a/src/intcox.rs b/src/intcox.rs index 05c68dc..861c561 100644 --- a/src/intcox.rs +++ b/src/intcox.rs @@ -16,10 +16,10 @@ const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64 -use std::fs; use std::io; use clap::{Args, ValueEnum}; +use csv::{Reader, StringRecord}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; use nalgebra::{DMatrix, DVector, Matrix1xX}; use prettytable::{Table, format, row}; @@ -58,46 +58,8 @@ enum OutputFormat { } pub fn main(args: IntCoxArgs) { - let lines: Vec<String>; - if args.input == "-" { - lines = io::stdin().lines().map(|l| l.unwrap()).collect(); - } else { - let contents = fs::read_to_string(args.input).unwrap(); - lines = contents.trim_end().split("\n").map(|s| s.to_string()).collect(); - } - - // Read data into matrices - - let mut data_times: DMatrix<f64> = DMatrix::zeros( - 2, // Left time, right time - lines.len() - 1 // Minus 1 row for header row - ); - - // Called "Z" in the paper and "X" in the C++ code - let mut data_indep: DMatrix<f64> = DMatrix::zeros( - lines[0].split(",").count() - 2, - lines.len() - 1 // Minus 1 row for header row - ); - - // Read header row - let indep_names: Vec<&str> = lines[0].split(",").skip(2).collect(); - // Read data - // FIXME: Parse CSV more robustly - for (i, row) in lines.iter().skip(1).enumerate() { - for (j, item) in row.split(",").enumerate() { - let value = match item { - "inf"
=> f64::INFINITY, - _ => item.parse().expect("Malformed float") - }; - - if j < 2 { - data_times[(j, i)] = value; - } else { - data_indep[(j - 2, i)] = value; - } - } - } + let (indep_names, data_times, data_indep) = read_data(&args.input); // Fit regression let progress_bar = ProgressBar::with_draw_target(Some(0), ProgressDrawTarget::term_like(Box::new(UnconditionalTermLike::stderr()))); @@ -137,6 +99,55 @@ pub fn main(args: IntCoxArgs) { } } +pub fn read_data(path: &str) -> (Vec<String>, DMatrix<f64>, DMatrix<f64>) { + // Read CSV into memory + let headers: StringRecord; + let records: Vec<StringRecord>; + if path == "-" { + let mut csv_reader = Reader::from_reader(io::stdin()); + headers = csv_reader.headers().unwrap().clone(); + records = csv_reader.records().map(|r| r.unwrap()).collect(); + } else { + let mut csv_reader = Reader::from_path(path).unwrap(); + headers = csv_reader.headers().unwrap().clone(); + records = csv_reader.records().map(|r| r.unwrap()).collect(); + } + + // Read data into matrices + + let mut data_times: DMatrix<f64> = DMatrix::zeros( + 2, // Left time, right time + records.len() + ); + + // Called "Z" in the paper and "X" in the C++ code + let mut data_indep: DMatrix<f64> = DMatrix::zeros( + headers.len() - 2, + records.len() + ); + + // Parse header row + let indep_names: Vec<String> = headers.iter().skip(2).map(String::from).collect(); + + // Parse data + for (i, row) in records.iter().enumerate() { + for (j, item) in row.iter().enumerate() { + let value = match item { + "inf" => f64::INFINITY, + _ => item.parse().expect("Malformed float") + }; + + if j < 2 { + data_times[(j, i)] = value; + } else { + data_indep[(j - 2, i)] = value; + } + } + } + + return (indep_names, data_times, data_indep); +} + struct IntervalCensoredCoxData { data_times: DMatrix<f64>, data_indep: DMatrix<f64>, diff --git a/tests/intcox.rs b/tests/intcox.rs index fb7974d..626da22 100644 --- a/tests/intcox.rs +++ b/tests/intcox.rs @@ -14,54 +14,20 @@ // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. -use std::fs; - use indicatif::ProgressBar; -use nalgebra::DMatrix; -use hpstat::intcox::fit_interval_censored_cox; +use hpstat::intcox; #[test] fn test_intcox_zeng_mao_lin() { // Compare "Bangkok Metropolitan Administration HIV" data from Zeng, Mao & Lin (2016) with Stata 17 output - let contents = fs::read_to_string("tests/zeng_mao_lin.csv").unwrap(); - let lines: Vec<String> = contents.trim_end().split("\n").map(|s| s.to_string()).collect(); - - // Read data into matrices - - let mut data_times: DMatrix<f64> = DMatrix::zeros( - 2, // Left time, right time - lines.len() - 1 // Minus 1 row for header row - ); - - // Called "Z" in the paper and "X" in the C++ code - let mut data_indep: DMatrix<f64> = DMatrix::zeros( - lines[0].split(",").count() - 2, - lines.len() - 1 // Minus 1 row for header row - ); - - // Read data - // FIXME: Parse CSV more robustly - for (i, row) in lines.iter().skip(1).enumerate() { - for (j, item) in row.split(",").enumerate() { - let value = match item { - "inf" => f64::INFINITY, - _ => item.parse().expect("Malformed float") - }; - - if j < 2 { - data_times[(j, i)] = value; - } else { - data_indep[(j - 2, i)] = value; - } - } - } + let (_indep_names, data_times, data_indep) = intcox::read_data("tests/zeng_mao_lin.csv"); // Fit regression let progress_bar = ProgressBar::hidden(); //let result = fit_interval_censored_cox(data_times, data_indep, 200, 0.00005, false, progress_bar); - let result = fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar); + let result = intcox::fit_interval_censored_cox(data_times, data_indep, 100, 0.0001, false, progress_bar); // import delimited "zeng_mao_lin.csv", case(preserve) numericcols(2) // stintcox Needle Needle2 LogAge GenderM RaceO RaceW GenderM_RaceO GenderM_RaceW, interval(Left_Time Right_Time) full nohr favorspeed lrmodel