diff --git a/austax/__init__.py b/austax/__init__.py index 4cd4bcc..68c51c6 100644 --- a/austax/__init__.py +++ b/austax/__init__.py @@ -79,8 +79,7 @@ def make_tax_transactions(): )) # Get trial balance - balancer = TrialBalancer() - balancer.apply_transactions(db.session.scalars(db.select(Transaction).options(db.selectinload(Transaction.postings))).all()) + balancer = TrialBalancer.from_cached() accounts = dict(sorted(balancer.accounts.items())) diff --git a/austax/reports.py b/austax/reports.py index 744f425..2b441ff 100644 --- a/austax/reports.py +++ b/austax/reports.py @@ -84,8 +84,7 @@ def study_loan_repayment(year, taxable_income, rfb_grossedup): @assert_aud def tax_summary_report(): # Get trial balance - balancer = TrialBalancer() - balancer.apply_transactions(db.session.scalars(db.select(Transaction).where((Transaction.dt >= sofy_date()) & (Transaction.dt <= eofy_date())).options(db.selectinload(Transaction.postings))).all()) + balancer = TrialBalancer.from_cached(start_date=sofy_date(), end_date=eofy_date()) accounts = dict(sorted(balancer.accounts.items())) diff --git a/drcr/models.py b/drcr/models.py index 1ba222c..5a587d8 100644 --- a/drcr/models.py +++ b/drcr/models.py @@ -64,6 +64,10 @@ class Posting(db.Model): quantity = db.Column(db.Integer) commodity = db.Column(db.String) + # Running balance of the account in units of reporting_commodity + # Only takes into consideration Transactions stored in database, not API-generated ones + running_balance = db.Column(db.Integer) + transaction = db.relationship('Transaction', back_populates='postings') def __init__(self, description=None, account=None, quantity=None, commodity=None): @@ -196,13 +200,83 @@ class TrialBalancer: def __init__(self): self.accounts = {} + @classmethod + def from_cached(cls, start_date=None, end_date=None): + """Obtain a TrialBalancer based on the cached running_balance""" + + if start_date is not None: + result_start_date = cls() + + # First SELECT the last applicable dt by account + # Then, among the transactions with that dt, SELECT the last applicable transaction_id + # Then extract the running_balance for each account at that transaction_id + running_balances = db.session.execute(''' + SELECT p3.account, running_balance FROM + ( + SELECT p1.account, max(p2.transaction_id) AS max_tid FROM + ( + SELECT account, max(dt) AS max_dt FROM postings JOIN transactions ON postings.transaction_id = transactions.id WHERE dt < :start_date GROUP BY account + ) p1 + JOIN postings p2 ON p1.account = p2.account AND p1.max_dt = transactions.dt JOIN transactions ON p2.transaction_id = transactions.id GROUP BY p2.account + ) p3 + JOIN postings p4 ON p3.account = p4.account AND p3.max_tid = p4.transaction_id + ''', {'start_date': start_date}) + + for running_balance in running_balances.all(): + result_start_date.accounts[running_balance.account] = Amount(running_balance.running_balance, reporting_commodity()) + + if end_date is None: + result = cls() + + running_balances = db.session.execute(''' + SELECT p3.account, running_balance FROM + ( + SELECT p1.account, max(p2.transaction_id) AS max_tid FROM + ( + SELECT account, max(dt) AS max_dt FROM postings JOIN transactions ON postings.transaction_id = transactions.id GROUP BY account + ) p1 + JOIN postings p2 ON p1.account = p2.account AND p1.max_dt = transactions.dt JOIN transactions ON p2.transaction_id = transactions.id GROUP BY p2.account + ) p3 + JOIN postings p4 ON p3.account = p4.account AND p3.max_tid = p4.transaction_id + ''') + + for running_balance in running_balances.all(): + result.accounts[running_balance.account] = Amount(running_balance.running_balance, reporting_commodity()) + + if end_date is not None: + result = cls() + + running_balances = db.session.execute(''' + SELECT p3.account, running_balance FROM + ( + SELECT p1.account, max(p2.transaction_id) AS max_tid FROM + ( + SELECT account, max(dt) AS max_dt FROM postings JOIN transactions ON postings.transaction_id = transactions.id WHERE dt <= :end_date GROUP BY account + ) p1 + JOIN postings p2 ON p1.account = p2.account AND p1.max_dt = transactions.dt JOIN transactions ON p2.transaction_id = transactions.id GROUP BY p2.account + ) p3 + JOIN postings p4 ON p3.account = p4.account AND p3.max_tid = p4.transaction_id + ''', {'end_date': end_date}) + + for running_balance in running_balances.all(): + result.accounts[running_balance.account] = Amount(running_balance.running_balance, reporting_commodity()) + + # Subtract balances at start_date from balances at end_date if required + if start_date is not None: + for k in result.accounts.keys(): + # If k not in result_start_date, then the balance at start_date was necessarily 0 and subtraction is not required + if k in result_start_date.accounts: + result.accounts[k].quantity -= result_start_date.accounts[k].quantity + + return result + def apply_transactions(self, transactions): for transaction in transactions: for posting in transaction.postings: if posting.account not in self.accounts: self.accounts[posting.account] = Amount(0, reporting_commodity()) - # FIXME: Handle commodities better + # FIXME: Handle commodities better (ensure compatible commodities) self.accounts[posting.account].quantity += posting.amount().as_cost().quantity def transfer_balance(self, source_account, destination_account, description=None): diff --git a/drcr/reports.py b/drcr/reports.py index 73c93da..70872a9 100644 --- a/drcr/reports.py +++ b/drcr/reports.py @@ -17,7 +17,7 @@ from flask import url_for from .models import AccountConfiguration, Amount, TrialBalancer, reporting_commodity -from .webapp import all_transactions, eofy_date, sofy_date +from .webapp import all_transactions, api_transactions, eofy_date, sofy_date from datetime import datetime, timedelta @@ -154,8 +154,8 @@ def entries_for_kind(account_configurations, accounts, kind, neg=False, floor=0) def balance_sheet_report(): # Get trial balance - balancer = TrialBalancer() - balancer.apply_transactions(all_transactions()) + balancer = TrialBalancer.from_cached() + balancer.apply_transactions(api_transactions()) accounts = dict(sorted(balancer.accounts.items())) @@ -207,8 +207,8 @@ def income_statement_report(start_date=None, end_date=None): end_date = eofy_date() # Get trial balance - balancer = TrialBalancer() - balancer.apply_transactions(all_transactions(start_date=start_date, end_date=end_date)) + balancer = TrialBalancer.from_cached(start_date=start_date, end_date=end_date) + balancer.apply_transactions(api_transactions(start_date=start_date, end_date=end_date)) accounts = dict(sorted(balancer.accounts.items())) diff --git a/drcr/views.py b/drcr/views.py index c91ade8..83dd940 100644 --- a/drcr/views.py +++ b/drcr/views.py @@ -20,7 +20,7 @@ from .database import db from .models import AccountConfiguration, Amount, Balance, Posting, TrialBalancer, reporting_commodity from .plugins import account_kinds, advanced_reports, data_sources from .reports import balance_sheet_report, income_statement_report -from .webapp import all_transactions, app +from .webapp import all_transactions, api_transactions, app from itertools import groupby @@ -69,8 +69,8 @@ def general_ledger(): @app.route('/trial-balance') def trial_balance(): - balancer = TrialBalancer() - balancer.apply_transactions(all_transactions()) + balancer = TrialBalancer.from_cached() + balancer.apply_transactions(api_transactions()) total_dr = Amount(sum(v.quantity for v in balancer.accounts.values() if v.quantity > 0), reporting_commodity()) total_cr = Amount(sum(v.quantity for v in balancer.accounts.values() if v.quantity < 0), reporting_commodity()) diff --git a/drcr/webapp.py b/drcr/webapp.py index 8b2324f..6f55560 100644 --- a/drcr/webapp.py +++ b/drcr/webapp.py @@ -22,7 +22,7 @@ app.config.from_file('config.toml', load=toml.load) from flask_sqlalchemy.record_queries import get_recorded_queries from .database import db -from .models import Metadata, Transaction, reporting_commodity +from .models import Amount, Metadata, Transaction, reporting_commodity from .plugins import init_plugins, transaction_providers from .statements.models import StatementLine @@ -33,6 +33,8 @@ app.config['SQLALCHEMY_RECORD_QUERIES'] = app.debug db.init_app(app) def all_transactions(start_date=None, end_date=None, join_postings=True): + """Return all transactions, including from DB and API""" + # All Transactions in database between start_date and end_date query = db.select(Transaction) if start_date and end_date: @@ -46,10 +48,21 @@ def all_transactions(start_date=None, end_date=None, join_postings=True): transactions = db.session.scalars(query).all() + transactions.extend(api_transactions(start_date, end_date)) + + return transactions + +def api_transactions(start_date=None, end_date=None): + """Return only transactions from API""" + + transactions = [] + # Unreconciled StatementLines + # FIXME: Filter by start_date and end_date transactions.extend(line.into_transaction() for line in StatementLine.query.filter(StatementLine.reconciliation == None)) # Plugins + # FIXME: Filter by start_date and end_date for transaction_provider in transaction_providers: transactions.extend(transaction_provider()) @@ -80,10 +93,33 @@ init_plugins() @app.cli.command('initdb') def initdb(): + """Initialise database tables""" + db.create_all() # FIXME: Need to init metadata +@app.cli.command('recache_balances') +def recache_balances(): + """Recompute running_balance for all postings""" + + # Get all Transactions in database in correct order + transactions = db.session.scalars(db.select(Transaction).options(db.selectinload(Transaction.postings)).order_by(Transaction.dt, Transaction.id)).all() + + accounts = {} + + for transaction in transactions: + for posting in transaction.postings: + if posting.account not in accounts: + accounts[posting.account] = Amount(0, reporting_commodity()) + + # FIXME: Handle commodities better (ensure compatible commodities) + accounts[posting.account].quantity += posting.amount().as_cost().quantity + + posting.running_balance = accounts[posting.account].quantity + + db.session.commit() + @app.context_processor def add_reporting_commodity(): return dict(reporting_commodity=reporting_commodity())