diff --git a/austax/__init__.py b/austax/__init__.py index 68c51c6..93997f0 100644 --- a/austax/__init__.py +++ b/austax/__init__.py @@ -49,13 +49,16 @@ def plugin_init(): drcr.plugins.transaction_providers.append(make_tax_transactions) @assert_aud -def make_tax_transactions(): - report = tax_summary_report() - tax_amount = report.by_id('total_tax').amount - +def make_tax_transactions(start_date=None, end_date=None): # Get EOFY date dt = eofy_date() + if (start_date is not None and start_date > dt) or (end_date is not None and end_date < dt): + return [] + + report = tax_summary_report() + tax_amount = report.by_id('total_tax').amount - report.by_id('offsets').amount + # Estimated tax payable transactions = [Transaction( dt=dt, diff --git a/drcr/webapp.py b/drcr/webapp.py index 6f55560..312ca9c 100644 --- a/drcr/webapp.py +++ b/drcr/webapp.py @@ -32,23 +32,29 @@ import time app.config['SQLALCHEMY_RECORD_QUERIES'] = app.debug db.init_app(app) +def limit_query_dt(query, field, start_date=None, end_date=None): + """Helper function to limit the query between the start and end dates""" + + if start_date and end_date: + return query.where((field >= start_date) & (field <= end_date)) + if start_date: + return query.where(field >= start_date) + if end_date: + return query.where(field <= end_date) + return query + def all_transactions(start_date=None, end_date=None, join_postings=True): """Return all transactions, including from DB and API""" # All Transactions in database between start_date and end_date query = db.select(Transaction) - if start_date and end_date: - query = query.where((Transaction.dt >= start_date) & (Transaction.dt <= end_date)) - elif start_date: - query = query.where(Transaction.dt >= start_date) - elif end_date: - query = query.where(Transaction.dt <= end_date) + query = limit_query_dt(query, Transaction.dt, start_date, end_date) if join_postings: query = query.options(db.selectinload(Transaction.postings)) transactions = db.session.scalars(query).all() - transactions.extend(api_transactions(start_date, end_date)) + transactions.extend(api_transactions(start_date=start_date, end_date=end_date)) return transactions @@ -58,13 +64,13 @@ def api_transactions(start_date=None, end_date=None): transactions = [] # Unreconciled StatementLines - # FIXME: Filter by start_date and end_date - transactions.extend(line.into_transaction() for line in StatementLine.query.filter(StatementLine.reconciliation == None)) + query = db.select(StatementLine).where(StatementLine.reconciliation == None) + query = limit_query_dt(query, StatementLine.dt, start_date, end_date) + transactions.extend(line.into_transaction() for line in db.session.scalars(query).all()) # Plugins - # FIXME: Filter by start_date and end_date for transaction_provider in transaction_providers: - transactions.extend(transaction_provider()) + transactions.extend(transaction_provider(start_date=start_date, end_date=end_date)) return transactions