diff --git a/.gitignore b/.gitignore index 84920f9..975657e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,4 @@ __pycache__ /drcr/config.py -/mydata -/scripts +/instance /venv -*.db diff --git a/drcr/database.py b/drcr/database.py index 251ab4c..60c844a 100644 --- a/drcr/database.py +++ b/drcr/database.py @@ -14,17 +14,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker -from sqlalchemy.ext.declarative import declarative_base +from flask_sqlalchemy import SQLAlchemy -engine = create_engine('sqlite:///drcr.db') -db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) -Base = declarative_base() -Base.query = db_session.query_property() - -def init_db(): - from .general_journal import models - from .statements import models - - Base.metadata.create_all(bind=engine) +db = SQLAlchemy() diff --git a/drcr/general_journal/models.py b/drcr/general_journal/models.py index 359216f..33fb528 100644 --- a/drcr/general_journal/models.py +++ b/drcr/general_journal/models.py @@ -14,48 +14,45 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import relationship - -from ..database import Base +from ..database import db from ..models import Amount, Posting, Transaction -class GeneralJournalTransaction(Base, Transaction): +class GeneralJournalTransaction(db.Model, Transaction): __tablename__ = 'general_journal_transactions' - id = Column(Integer, primary_key=True) + id = db.Column(db.Integer, primary_key=True) - dt = Column(DateTime) - description = Column(String) + dt = db.Column(db.DateTime) + description = db.Column(db.String) - postings = relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan') + postings = db.relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan') -class GeneralJournalPosting(Base, Posting): +class GeneralJournalPosting(db.Model, Posting): __tablename__ = 'general_journal_postings' - id = Column(Integer, primary_key=True) - transaction_id = Column(Integer, ForeignKey('general_journal_transactions.id')) + id = db.Column(db.Integer, primary_key=True) + transaction_id = db.Column(db.Integer, db.ForeignKey('general_journal_transactions.id')) - description = Column(String) - account = Column(String) - quantity = Column(Integer) - commodity = Column(String) + description = db.Column(db.String) + account = db.Column(db.String) + quantity = db.Column(db.Integer) + commodity = db.Column(db.String) - transaction = relationship('GeneralJournalTransaction', back_populates='postings') + transaction = db.relationship('GeneralJournalTransaction', back_populates='postings') def amount(self): return Amount(self.quantity, self.commodity) -class BalanceAssertion(Base): +class BalanceAssertion(db.Model): __tablename__ = 'balance_assertions' - id = Column(Integer, primary_key=True) + id = db.Column(db.Integer, primary_key=True) - dt = Column(DateTime) - description = Column(String) - account = Column(String) - quantity = Column(Integer) - commodity = Column(String) + dt = db.Column(db.DateTime) + description = db.Column(db.String) + account = db.Column(db.String) + quantity = db.Column(db.Integer) + commodity = db.Column(db.String) def balance(self): return Amount(self.quantity, self.commodity) diff --git a/drcr/general_journal/views.py b/drcr/general_journal/views.py index 91f340c..748f1ed 100644 --- a/drcr/general_journal/views.py +++ b/drcr/general_journal/views.py @@ -17,7 +17,7 @@ from flask import abort, redirect, render_template, request from .. import AMOUNT_DPS -from ..database import db_session +from ..database import db from ..models import TrialBalancer from ..webapp import all_transactions, app from .models import Amount, BalanceAssertion, GeneralJournalPosting, GeneralJournalTransaction @@ -58,14 +58,14 @@ def general_journal_new(): transaction.assert_valid() - db_session.add(transaction) - db_session.commit() + db.session.add(transaction) + db.session.commit() return redirect('/general-journal') @app.route('/general-journal/edit', methods=['GET', 'POST']) def general_journal_edit(): - transaction = db_session.get(GeneralJournalTransaction, request.args['id']) + transaction = db.session.get(GeneralJournalTransaction, request.args['id']) if not transaction: abort(404) @@ -91,7 +91,7 @@ def general_journal_edit(): transaction.assert_valid() - db_session.commit() + db.session.commit() return redirect('/general-journal') @@ -132,14 +132,14 @@ def balance_assertions_new(): quantity=quantity, commodity='$' ) - db_session.add(assertion) - db_session.commit() + db.session.add(assertion) + db.session.commit() return redirect('/balance-assertions') @app.route('/balance-assertions/edit', methods=['GET', 'POST']) def balance_assertions_edit(): - assertion = db_session.get(BalanceAssertion, request.args['id']) + assertion = db.session.get(BalanceAssertion, request.args['id']) if not assertion: abort(404) @@ -156,6 +156,6 @@ def balance_assertions_edit(): assertion.account = request.form['account'] assertion.quantity = quantity - db_session.commit() + db.session.commit() return redirect('/balance-assertions') diff --git a/drcr/statements/models.py b/drcr/statements/models.py index 1d6b61d..951865b 100644 --- a/drcr/statements/models.py +++ b/drcr/statements/models.py @@ -14,25 +14,22 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import relationship - -from ..database import Base, db_session +from ..database import db from ..models import Amount, Posting, Transaction -class StatementLine(Base): +class StatementLine(db.Model): __tablename__ = 'statement_lines' - id = Column(Integer, primary_key=True) + id = db.Column(db.Integer, primary_key=True) - source_account = Column(String) - dt = Column(DateTime) - description = Column(String) - quantity = Column(Integer) - balance = Column(Integer) - commodity = Column(String) + source_account = db.Column(db.String) + dt = db.Column(db.DateTime) + description = db.Column(db.String) + quantity = db.Column(db.Integer) + balance = db.Column(db.Integer) + commodity = db.Column(db.String) - postings = relationship('StatementLinePosting', back_populates='statement_line') + postings = db.relationship('StatementLinePosting', back_populates='statement_line') def amount(self): return Amount(self.quantity, self.commodity) @@ -69,17 +66,17 @@ class StatementLine(Base): for posting in self.postings: # TODO: Will be wonky if transaction covers multiple StatementLines - db_session.delete(posting.transaction) + db.session.delete(posting.transaction) -class StatementLineTransaction(Base, Transaction): +class StatementLineTransaction(db.Model, Transaction): __tablename__ = 'statement_line_transactions' - id = Column(Integer, primary_key=True) + id = db.Column(db.Integer, primary_key=True) - dt = Column(DateTime) - description = Column(String) + dt = db.Column(db.DateTime) + description = db.Column(db.String) - postings = relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete') + postings = db.relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete') def charge_account(self, source_account): if len(self.postings) > 2: @@ -89,20 +86,20 @@ class StatementLineTransaction(Base, Transaction): if posting.account != source_account: return posting.account -class StatementLinePosting(Base, Posting): +class StatementLinePosting(db.Model, Posting): __tablename__ = 'statement_line_postings' - id = Column(Integer, primary_key=True) - transaction_id = Column(Integer, ForeignKey('statement_line_transactions.id')) - line_id = Column(Integer, ForeignKey('statement_lines.id')) + id = db.Column(db.Integer, primary_key=True) + transaction_id = db.Column(db.Integer, db.ForeignKey('statement_line_transactions.id')) + line_id = db.Column(db.Integer, db.ForeignKey('statement_lines.id')) - description = Column(String) - account = Column(String) - quantity = Column(Integer) - commodity = Column(String) + description = db.Column(db.String) + account = db.Column(db.String) + quantity = db.Column(db.Integer) + commodity = db.Column(db.String) - transaction = relationship('StatementLineTransaction', back_populates='postings') - statement_line = relationship('StatementLine', back_populates='postings') + transaction = db.relationship('StatementLineTransaction', back_populates='postings') + statement_line = db.relationship('StatementLine', back_populates='postings') def amount(self): return Amount(self.quantity, self.commodity) diff --git a/drcr/statements/views.py b/drcr/statements/views.py index 50d1d91..8b3b288 100644 --- a/drcr/statements/views.py +++ b/drcr/statements/views.py @@ -17,7 +17,7 @@ from flask import abort, redirect, render_template, request from .. import AMOUNT_DPS -from ..database import db_session +from ..database import db from ..webapp import app from .models import StatementLine, StatementLinePosting, StatementLineTransaction @@ -32,7 +32,7 @@ def statement_lines(): @app.route('/statement-lines/charge', methods=['POST']) def statement_line_charge(): - statement_line = db_session.get(StatementLine, request.form['line-id']) + statement_line = db.session.get(StatementLine, request.form['line-id']) if not statement_line: abort(404) @@ -48,14 +48,14 @@ def statement_line_charge(): ] ) - db_session.add(transaction) - db_session.commit() + db.session.add(transaction) + db.session.commit() return 'OK' @app.route('/statement-lines/edit-transaction', methods=['GET', 'POST']) def statement_line_edit_transaction(): - statement_line = db_session.get(StatementLine, request.args['line-id']) + statement_line = db.session.get(StatementLine, request.args['line-id']) if not statement_line: abort(404) @@ -99,8 +99,8 @@ def statement_line_edit_transaction(): ) transaction.assert_valid() - db_session.add(transaction) - db_session.commit() + db.session.add(transaction) + db.session.commit() return redirect('/statement-lines') @@ -110,8 +110,8 @@ def statement_line_reconcile_transfer(): if len(line_ids) != 2: raise Exception('Must select exactly 2 statement lines') - line1 = db_session.get(StatementLine, line_ids[0]) - line2 = db_session.get(StatementLine, line_ids[1]) + line1 = db.session.get(StatementLine, line_ids[0]) + line2 = db.session.get(StatementLine, line_ids[1]) # Check same amount if line1.quantity != -line2.quantity or line1.commodity != line2.commodity: @@ -129,7 +129,7 @@ def statement_line_reconcile_transfer(): StatementLinePosting(statement_line=line2, description=line2.description, account=line2.source_account, quantity=line2.quantity, commodity=line2.commodity) ] ) - db_session.add(transaction) - db_session.commit() + db.session.add(transaction) + db.session.commit() return redirect('/statement-lines') diff --git a/drcr/webapp.py b/drcr/webapp.py index 70d22ae..91fe586 100644 --- a/drcr/webapp.py +++ b/drcr/webapp.py @@ -16,12 +16,16 @@ from flask import Flask -from .database import db_session, init_db +from .database import db from .general_journal.models import GeneralJournalTransaction from .statements.models import StatementLine, StatementLineTransaction app = Flask(__name__) +app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///drcr.db' +app.config['SQLALCHEMY_RECORD_QUERIES'] = True +db.init_app(app) + def all_transactions(): return ( GeneralJournalTransaction.query.all() + @@ -36,7 +40,7 @@ from .statements import views @app.cli.command('initdb') def initdb(): - init_db() + db.create_all() @app.teardown_appcontext def shutdown_session(exception=None):