From 280a2090d9d52cf300031e9e8cc7ae2a31ca549e Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Tue, 27 May 2025 17:28:34 +1000 Subject: [PATCH] Execute reporting steps in parallel --- src/main.rs | 13 +++++-- src/reporting/calculator.rs | 2 +- src/reporting/executor.rs | 67 +++++++++++++++++++++++++++++++------ src/reporting/mod.rs | 6 ++-- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2037235..7ab29a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,8 @@ along with this program. If not, see . */ +use std::sync::Arc; + use chrono::NaiveDate; use libdrcr::db::DbConnection; use libdrcr::reporting::builders::register_dynamic_builders; @@ -41,10 +43,11 @@ async fn main() { NaiveDate::from_ymd_opt(2025, 6, 30).unwrap(), "$".to_string(), ); - register_lookup_fns(&mut context); register_dynamic_builders(&mut context); + let context = Arc::new(context); + // Print Graphviz let targets = vec![ @@ -86,7 +89,9 @@ async fn main() { }, ]; - let products = generate_report(targets, &context).await.unwrap(); + let products = generate_report(targets, Arc::clone(&context)) + .await + .unwrap(); let result = products .get_or_err(&ReportingProductId { name: "AllTransactionsExceptEarningsToEquity", @@ -120,7 +125,9 @@ async fn main() { }, ]; - let products = generate_report(targets, &context).await.unwrap(); + let products = generate_report(targets, Arc::clone(&context)) + .await + .unwrap(); let result = products .get_or_err(&ReportingProductId { name: "BalanceSheet", diff --git a/src/reporting/calculator.rs b/src/reporting/calculator.rs index 35de398..410aff2 100644 --- a/src/reporting/calculator.rs +++ b/src/reporting/calculator.rs @@ -179,7 +179,7 @@ fn build_step_for_product( } /// Check whether the [ReportingStep] would be ready to execute, if the given previous steps have already completed -fn would_be_ready_to_execute( +pub(crate) fn would_be_ready_to_execute( step: &Box, steps: &Vec>, dependencies: &ReportingGraphDependencies, diff --git a/src/reporting/executor.rs b/src/reporting/executor.rs index c8830e8..dd42c55 100644 --- a/src/reporting/executor.rs +++ b/src/reporting/executor.rs @@ -16,10 +16,12 @@ along with this program. If not, see . */ -use tokio::sync::RwLock; +use std::sync::Arc; + +use tokio::{sync::RwLock, task::JoinSet}; use super::{ - calculator::ReportingGraphDependencies, + calculator::{would_be_ready_to_execute, ReportingGraphDependencies}, types::{ReportingContext, ReportingProducts, ReportingStep}, }; @@ -28,19 +30,62 @@ pub enum ReportingExecutionError { DependencyNotAvailable { message: String }, } +async fn execute_step( + step_idx: usize, + steps: Arc>>, + dependencies: Arc, + context: Arc, + products: Arc>, +) -> (usize, Result) { + let step = &steps[step_idx]; + let result = step + .execute(&*context, &*steps, &*dependencies, &*products) + .await; + + (step_idx, result) +} + pub async fn execute_steps( steps: Vec>, dependencies: ReportingGraphDependencies, - context: &ReportingContext, + context: Arc, ) -> Result { - let products = RwLock::new(ReportingProducts::new()); + let products = Arc::new(RwLock::new(ReportingProducts::new())); - for step in steps.iter() { - // Execute the step - // TODO: Do this in parallel - let mut new_products = step - .execute(context, &steps, &dependencies, &products) - .await?; + // Prepare for async + let steps = Arc::new(steps); + let dependencies = Arc::new(dependencies); + + // Execute steps asynchronously + let mut handles = JoinSet::new(); + let mut steps_done = Vec::new(); + let mut steps_remaining = (0..steps.len()).collect::>(); + + while steps_done.len() != steps.len() { + // Execute each step which is ready to run + for step_idx in steps_remaining.iter().copied().collect::>() { + // Check if ready to run + if would_be_ready_to_execute(&steps[step_idx], &steps, &dependencies, &steps_done) { + // Spawn new task + // Unfortunately the compiler cannot guarantee lifetimes are correct, so we must pass Arc across thread boundaries + handles.spawn(execute_step( + step_idx, + Arc::clone(&steps), + Arc::clone(&dependencies), + Arc::clone(&context), + Arc::clone(&products), + )); + steps_remaining + .remove(steps_remaining.iter().position(|i| *i == step_idx).unwrap()); + } + } + + // Join next result + let (step_idx, result) = handles.join_next().await.unwrap().unwrap(); + let step = &steps[step_idx]; + steps_done.push(step_idx); + + let mut new_products = result?; // Sanity check the new products for (product_id, _product) in new_products.map().iter() { @@ -71,5 +116,5 @@ pub async fn execute_steps( products.write().await.append(&mut new_products); } - Ok(products.into_inner()) + Ok(Arc::into_inner(products).unwrap().into_inner()) } diff --git a/src/reporting/mod.rs b/src/reporting/mod.rs index 9214bd3..310e076 100644 --- a/src/reporting/mod.rs +++ b/src/reporting/mod.rs @@ -16,6 +16,8 @@ along with this program. If not, see . */ +use std::sync::Arc; + use calculator::{steps_for_targets, ReportingCalculationError}; use executor::{execute_steps, ReportingExecutionError}; use types::{ReportingContext, ReportingProductId, ReportingProducts}; @@ -50,10 +52,10 @@ impl From for ReportingError { /// Helper function to call [steps_for_targets] followed by [execute_steps]. pub async fn generate_report( targets: Vec, - context: &ReportingContext, + context: Arc, ) -> Result { // Solve dependencies - let (sorted_steps, dependencies) = steps_for_targets(targets, context)?; + let (sorted_steps, dependencies) = steps_for_targets(targets, &*context)?; // Execute steps let products = execute_steps(sorted_steps, dependencies, context).await?;