diff --git a/drcr/journal/views.py b/drcr/journal/views.py index f8c4e40..627d3fc 100644 --- a/drcr/journal/views.py +++ b/drcr/journal/views.py @@ -18,7 +18,7 @@ from flask import abort, redirect, render_template, request, url_for from .. import AMOUNT_DPS from ..database import db -from ..models import Amount, Posting, Transaction, TrialBalancer, reporting_commodity +from ..models import Amount, Posting, Transaction, TrialBalancer, queue_invalidate_running_balances, reporting_commodity from ..webapp import all_accounts, all_transactions, app from .models import BalanceAssertion from ..statements.models import StatementLineReconciliation @@ -58,10 +58,13 @@ def journal_new_transaction(): commodity=amount.commodity ) transaction.postings.append(posting) + + # Invalidate future running balances + queue_invalidate_running_balances(account, transaction.dt) transaction.assert_valid() - db.session.add(transaction) + db.session.commit() return redirect(request.form.get('referrer', '') or url_for('journal')) @@ -103,6 +106,9 @@ def journal_edit_transaction(): commodity=amount.commodity ) new_postings.append(posting) + + # Invalidate future running balances + queue_invalidate_running_balances(account, transaction.dt) # Fix up reconciliations for old_posting in transaction.postings: diff --git a/drcr/models.py b/drcr/models.py index 5a587d8..77d6504 100644 --- a/drcr/models.py +++ b/drcr/models.py @@ -70,15 +70,26 @@ class Posting(db.Model): transaction = db.relationship('Transaction', back_populates='postings') - def __init__(self, description=None, account=None, quantity=None, commodity=None): + def __init__(self, description=None, account=None, quantity=None, commodity=None, running_balance=None): self.description = description self.account = account self.quantity = quantity self.commodity = commodity + self.running_balance = running_balance def amount(self): return Amount(self.quantity, self.commodity) +def queue_invalidate_running_balances(account, dt_from): + """ + Invalidate running_balances for Postings in the specified account, from the given date onwards + + NOTE: Does not call db.session.commit() + """ + + for posting in db.session.scalars(db.select(Posting).join(Posting.transaction).where((Transaction.dt >= dt_from) & (Posting.account == account))).all(): + posting.running_balance = None + class Amount: __slots__ = ['quantity', 'commodity'] @@ -204,6 +215,26 @@ class TrialBalancer: def from_cached(cls, start_date=None, end_date=None): """Obtain a TrialBalancer based on the cached running_balance""" + # First, recompute any running_balance if required + stale_accounts = db.session.scalars('SELECT DISTINCT account FROM postings WHERE running_balance IS NULL').all() + if stale_accounts: + # Get all relevant Postings in database in correct order + # FIXME: Recompute balances only from the last non-stale balance to be more efficient + postings = db.session.scalars(db.select(Posting).join(Posting.transaction).where(Posting.account.in_(stale_accounts)).order_by(Transaction.dt, Transaction.id)).all() + + accounts = {} + + for posting in postings: + if posting.account not in accounts: + accounts[posting.account] = Amount(0, reporting_commodity()) + + # FIXME: Handle commodities better (ensure compatible commodities) + accounts[posting.account].quantity += posting.amount().as_cost().quantity + + posting.running_balance = accounts[posting.account].quantity + + db.session.commit() + if start_date is not None: result_start_date = cls()