diff --git a/src/main.rs b/src/main.rs
index 2037235..7ab29a9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -16,6 +16,8 @@
along with this program. If not, see .
*/
+use std::sync::Arc;
+
use chrono::NaiveDate;
use libdrcr::db::DbConnection;
use libdrcr::reporting::builders::register_dynamic_builders;
@@ -41,10 +43,11 @@ async fn main() {
NaiveDate::from_ymd_opt(2025, 6, 30).unwrap(),
"$".to_string(),
);
-
register_lookup_fns(&mut context);
register_dynamic_builders(&mut context);
+ let context = Arc::new(context);
+
// Print Graphviz
let targets = vec![
@@ -86,7 +89,9 @@ async fn main() {
},
];
- let products = generate_report(targets, &context).await.unwrap();
+ let products = generate_report(targets, Arc::clone(&context))
+ .await
+ .unwrap();
let result = products
.get_or_err(&ReportingProductId {
name: "AllTransactionsExceptEarningsToEquity",
@@ -120,7 +125,9 @@ async fn main() {
},
];
- let products = generate_report(targets, &context).await.unwrap();
+ let products = generate_report(targets, Arc::clone(&context))
+ .await
+ .unwrap();
let result = products
.get_or_err(&ReportingProductId {
name: "BalanceSheet",
diff --git a/src/reporting/calculator.rs b/src/reporting/calculator.rs
index 35de398..410aff2 100644
--- a/src/reporting/calculator.rs
+++ b/src/reporting/calculator.rs
@@ -179,7 +179,7 @@ fn build_step_for_product(
}
/// Check whether the [ReportingStep] would be ready to execute, if the given previous steps have already completed
-fn would_be_ready_to_execute(
+pub(crate) fn would_be_ready_to_execute(
step: &Box,
steps: &Vec>,
dependencies: &ReportingGraphDependencies,
diff --git a/src/reporting/executor.rs b/src/reporting/executor.rs
index c8830e8..dd42c55 100644
--- a/src/reporting/executor.rs
+++ b/src/reporting/executor.rs
@@ -16,10 +16,12 @@
along with this program. If not, see .
*/
-use tokio::sync::RwLock;
+use std::sync::Arc;
+
+use tokio::{sync::RwLock, task::JoinSet};
use super::{
- calculator::ReportingGraphDependencies,
+ calculator::{would_be_ready_to_execute, ReportingGraphDependencies},
types::{ReportingContext, ReportingProducts, ReportingStep},
};
@@ -28,19 +30,62 @@ pub enum ReportingExecutionError {
DependencyNotAvailable { message: String },
}
+async fn execute_step(
+ step_idx: usize,
+ steps: Arc>>,
+ dependencies: Arc,
+ context: Arc,
+ products: Arc>,
+) -> (usize, Result) {
+ let step = &steps[step_idx];
+ let result = step
+ .execute(&*context, &*steps, &*dependencies, &*products)
+ .await;
+
+ (step_idx, result)
+}
+
pub async fn execute_steps(
steps: Vec>,
dependencies: ReportingGraphDependencies,
- context: &ReportingContext,
+ context: Arc,
) -> Result {
- let products = RwLock::new(ReportingProducts::new());
+ let products = Arc::new(RwLock::new(ReportingProducts::new()));
- for step in steps.iter() {
- // Execute the step
- // TODO: Do this in parallel
- let mut new_products = step
- .execute(context, &steps, &dependencies, &products)
- .await?;
+ // Prepare for async
+ let steps = Arc::new(steps);
+ let dependencies = Arc::new(dependencies);
+
+ // Execute steps asynchronously
+ let mut handles = JoinSet::new();
+ let mut steps_done = Vec::new();
+ let mut steps_remaining = (0..steps.len()).collect::>();
+
+ while steps_done.len() != steps.len() {
+ // Execute each step which is ready to run
+ for step_idx in steps_remaining.iter().copied().collect::>() {
+ // Check if ready to run
+ if would_be_ready_to_execute(&steps[step_idx], &steps, &dependencies, &steps_done) {
+ // Spawn new task
+ // Unfortunately the compiler cannot guarantee lifetimes are correct, so we must pass Arc across thread boundaries
+ handles.spawn(execute_step(
+ step_idx,
+ Arc::clone(&steps),
+ Arc::clone(&dependencies),
+ Arc::clone(&context),
+ Arc::clone(&products),
+ ));
+ steps_remaining
+ .remove(steps_remaining.iter().position(|i| *i == step_idx).unwrap());
+ }
+ }
+
+ // Join next result
+ let (step_idx, result) = handles.join_next().await.unwrap().unwrap();
+ let step = &steps[step_idx];
+ steps_done.push(step_idx);
+
+ let mut new_products = result?;
// Sanity check the new products
for (product_id, _product) in new_products.map().iter() {
@@ -71,5 +116,5 @@ pub async fn execute_steps(
products.write().await.append(&mut new_products);
}
- Ok(products.into_inner())
+ Ok(Arc::into_inner(products).unwrap().into_inner())
}
diff --git a/src/reporting/mod.rs b/src/reporting/mod.rs
index 9214bd3..310e076 100644
--- a/src/reporting/mod.rs
+++ b/src/reporting/mod.rs
@@ -16,6 +16,8 @@
along with this program. If not, see .
*/
+use std::sync::Arc;
+
use calculator::{steps_for_targets, ReportingCalculationError};
use executor::{execute_steps, ReportingExecutionError};
use types::{ReportingContext, ReportingProductId, ReportingProducts};
@@ -50,10 +52,10 @@ impl From for ReportingError {
/// Helper function to call [steps_for_targets] followed by [execute_steps].
pub async fn generate_report(
targets: Vec,
- context: &ReportingContext,
+ context: Arc,
) -> Result {
// Solve dependencies
- let (sorted_steps, dependencies) = steps_for_targets(targets, context)?;
+ let (sorted_steps, dependencies) = steps_for_targets(targets, &*context)?;
// Execute steps
let products = execute_steps(sorted_steps, dependencies, context).await?;