turnbull: Custom CSV implementation

Avoid unnecessary String allocation
13% speedup
This commit is contained in:
RunasSudo 2023-11-11 00:25:19 +11:00
parent b691c5a8d7
commit 8914cf3507
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
3 changed files with 145 additions and 35 deletions

133
src/csv.rs Normal file
View File

@ -0,0 +1,133 @@
// hpstat: High-performance statistics implementations
// Copyright © 2023 Lee Yingtong Li (RunasSudo)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
use std::io::BufRead;
pub fn read_csv<R: BufRead>(mut reader: R) -> (Vec<String>, Vec<f64>) {
// This custom CSV parser is faster than the csv library because we do not waste time allocating Strings for the data which will inevitably be parsed to float anyway
// Reuse a single buffer to avoid unnecessary allocations
// Since we need to make copies only for the headers - the data are directly parsed to float
let mut buffer = String::new();
// Read header
let headers = read_row_as_strings(&mut reader, &mut buffer);
// Read data
let mut data = Vec::new();
let mut row = Vec::new();
loop {
if read_row_as_floats(&mut reader, &mut buffer, &mut row) {
if row.len() != headers.len() {
panic!("Expected row of {} entries, got {} entries", headers.len(), row.len());
}
data.append(&mut row);
} else {
// EOF
break;
}
}
return (headers, data);
}
fn read_row_as_strings<R: BufRead>(reader: &mut R, buffer: &mut String) -> Vec<String> {
buffer.clear();
let bytes_read = reader.read_line(buffer).expect("IO error");
if bytes_read == 0 {
panic!("Unexpected EOF");
}
let mut result = Vec::new();
let mut entries_iter = buffer.trim().split(',');
loop {
if let Some(entry) = entries_iter.next() {
if entry.starts_with('"') {
if entry.ends_with('"') {
result.push(String::from(&entry[1..(entry.len() - 1)]));
} else {
let mut full_entry = String::from(&entry[1..]);
// Read remainder of quoted entry
loop {
if let Some(entry_part) = entries_iter.next() {
if entry_part.ends_with('"') {
// End of quoted entry
full_entry.push_str(&entry_part[..(entry_part.len() - 1)]);
result.push(full_entry);
break;
} else {
// Middle of quoted entry
full_entry.push_str(entry_part);
full_entry.push_str(&",");
}
} else {
panic!("Unexpected EOL while reading quoted CSV entry");
}
}
}
} else {
result.push(String::from(entry));
}
} else {
// EOL
break;
}
}
return result;
}
fn read_row_as_floats<R: BufRead>(reader: &mut R, buffer: &mut String, row: &mut Vec<f64>) -> bool {
buffer.clear();
let bytes_read = reader.read_line(buffer).expect("IO error");
if bytes_read == 0 {
// EOF
return false;
}
let mut entries_iter = buffer.trim().split(',');
loop {
if let Some(entry) = entries_iter.next() {
if entry.starts_with('"') {
if entry.ends_with('"') {
row.push(parse_float(&entry[1..(entry.len() - 1)]));
} else {
// Float cannot have a comma in it
panic!("Malformed float");
}
} else {
row.push(parse_float(entry));
}
} else {
// EOL
break;
}
}
return true;
}
fn parse_float(s: &str) -> f64 {
let value = match s {
"inf" => f64::INFINITY,
_ => s.parse().expect("Malformed float")
};
return value;
}

View File

@ -1,5 +1,6 @@
pub mod intcox; pub mod intcox;
pub mod turnbull; pub mod turnbull;
mod csv;
mod pava; mod pava;
mod term; mod term;

View File

@ -17,17 +17,17 @@
const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64 const Z_97_5: f64 = 1.959964; // This is the limit of resolution for an f64
const CHI2_1DF_95: f64 = 3.8414588; const CHI2_1DF_95: f64 = 3.8414588;
use core::mem::MaybeUninit; use std::fs::File;
use std::io; use std::io::{self, BufReader};
use clap::{Args, ValueEnum}; use clap::{Args, ValueEnum};
use csv::{Reader, StringRecord};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressDrawTarget, ProgressStyle};
use nalgebra::{Const, DMatrix, DVector, Dyn, Matrix2xX}; use nalgebra::{DMatrix, DVector, Matrix2xX};
use prettytable::{Table, format, row}; use prettytable::{Table, format, row};
use rayon::prelude::*; use rayon::prelude::*;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::csv::read_csv;
use crate::pava::monotonic_regression_pava; use crate::pava::monotonic_regression_pava;
use crate::term::UnconditionalTermLike; use crate::term::UnconditionalTermLike;
@ -134,45 +134,21 @@ pub fn main(args: TurnbullArgs) {
pub fn read_data(path: &str) -> Matrix2xX<f64> { pub fn read_data(path: &str) -> Matrix2xX<f64> {
// Read CSV into memory // Read CSV into memory
let _headers: StringRecord; let (_headers, records) = match path {
let records: Vec<StringRecord>; "-" => read_csv(io::stdin().lock()),
if path == "-" { _ => read_csv(BufReader::new(File::open(path).expect("IO error")))
let mut csv_reader = Reader::from_reader(io::stdin()); };
_headers = csv_reader.headers().unwrap().clone();
records = csv_reader.records().map(|r| r.unwrap()).collect();
} else {
let mut csv_reader = Reader::from_path(path).unwrap();
_headers = csv_reader.headers().unwrap().clone();
records = csv_reader.records().map(|r| r.unwrap()).collect();
}
// Read data into matrices // Read data into matrices
// Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter) // Represent data_times as 2xX rather than Xx2 matrix to allow par_column_iter in code_times_as_indexes (no par_row_iter)
let mut data_times: Matrix2xX<MaybeUninit<f64>> = Matrix2xX::uninit( // Serendipitously, from_vec fills column-by-column
Const::<2>, // Left time, right time let data_times = Matrix2xX::from_vec(records);
Dyn(records.len())
);
// Parse data
for (i, row) in records.iter().enumerate() {
for (j, item) in row.iter().enumerate() {
let value = match item {
"inf" => f64::INFINITY,
_ => item.parse().expect("Malformed float")
};
data_times[(j, i)].write(value);
}
}
// TODO: Fail on left time > right time // TODO: Fail on left time > right time
// TODO: Fail on left time < 0 // TODO: Fail on left time < 0
// SAFETY: assume_init is OK because we initialised all values above return data_times;
unsafe {
return data_times.assume_init();
}
} }
struct TurnbullData { struct TurnbullData {