Port from raw SQLAlchemy to Flask-SQLAlchemy

This commit is contained in:
RunasSudo 2023-01-02 16:59:12 +11:00
parent 3e6e1f18ed
commit d8caedd2dd
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
7 changed files with 76 additions and 91 deletions

4
.gitignore vendored
View File

@ -1,6 +1,4 @@
__pycache__ __pycache__
/drcr/config.py /drcr/config.py
/mydata /instance
/scripts
/venv /venv
*.db

View File

@ -14,17 +14,6 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import create_engine from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
engine = create_engine('sqlite:///drcr.db') db = SQLAlchemy()
db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))
Base = declarative_base()
Base.query = db_session.query_property()
def init_db():
from .general_journal import models
from .statements import models
Base.metadata.create_all(bind=engine)

View File

@ -14,48 +14,45 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String from ..database import db
from sqlalchemy.orm import relationship
from ..database import Base
from ..models import Amount, Posting, Transaction from ..models import Amount, Posting, Transaction
class GeneralJournalTransaction(Base, Transaction): class GeneralJournalTransaction(db.Model, Transaction):
__tablename__ = 'general_journal_transactions' __tablename__ = 'general_journal_transactions'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime) dt = db.Column(db.DateTime)
description = Column(String) description = db.Column(db.String)
postings = relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan') postings = db.relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan')
class GeneralJournalPosting(Base, Posting): class GeneralJournalPosting(db.Model, Posting):
__tablename__ = 'general_journal_postings' __tablename__ = 'general_journal_postings'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
transaction_id = Column(Integer, ForeignKey('general_journal_transactions.id')) transaction_id = db.Column(db.Integer, db.ForeignKey('general_journal_transactions.id'))
description = Column(String) description = db.Column(db.String)
account = Column(String) account = db.Column(db.String)
quantity = Column(Integer) quantity = db.Column(db.Integer)
commodity = Column(String) commodity = db.Column(db.String)
transaction = relationship('GeneralJournalTransaction', back_populates='postings') transaction = db.relationship('GeneralJournalTransaction', back_populates='postings')
def amount(self): def amount(self):
return Amount(self.quantity, self.commodity) return Amount(self.quantity, self.commodity)
class BalanceAssertion(Base): class BalanceAssertion(db.Model):
__tablename__ = 'balance_assertions' __tablename__ = 'balance_assertions'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime) dt = db.Column(db.DateTime)
description = Column(String) description = db.Column(db.String)
account = Column(String) account = db.Column(db.String)
quantity = Column(Integer) quantity = db.Column(db.Integer)
commodity = Column(String) commodity = db.Column(db.String)
def balance(self): def balance(self):
return Amount(self.quantity, self.commodity) return Amount(self.quantity, self.commodity)

View File

@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS from .. import AMOUNT_DPS
from ..database import db_session from ..database import db
from ..models import TrialBalancer from ..models import TrialBalancer
from ..webapp import all_transactions, app from ..webapp import all_transactions, app
from .models import Amount, BalanceAssertion, GeneralJournalPosting, GeneralJournalTransaction from .models import Amount, BalanceAssertion, GeneralJournalPosting, GeneralJournalTransaction
@ -58,14 +58,14 @@ def general_journal_new():
transaction.assert_valid() transaction.assert_valid()
db_session.add(transaction) db.session.add(transaction)
db_session.commit() db.session.commit()
return redirect('/general-journal') return redirect('/general-journal')
@app.route('/general-journal/edit', methods=['GET', 'POST']) @app.route('/general-journal/edit', methods=['GET', 'POST'])
def general_journal_edit(): def general_journal_edit():
transaction = db_session.get(GeneralJournalTransaction, request.args['id']) transaction = db.session.get(GeneralJournalTransaction, request.args['id'])
if not transaction: if not transaction:
abort(404) abort(404)
@ -91,7 +91,7 @@ def general_journal_edit():
transaction.assert_valid() transaction.assert_valid()
db_session.commit() db.session.commit()
return redirect('/general-journal') return redirect('/general-journal')
@ -132,14 +132,14 @@ def balance_assertions_new():
quantity=quantity, quantity=quantity,
commodity='$' commodity='$'
) )
db_session.add(assertion) db.session.add(assertion)
db_session.commit() db.session.commit()
return redirect('/balance-assertions') return redirect('/balance-assertions')
@app.route('/balance-assertions/edit', methods=['GET', 'POST']) @app.route('/balance-assertions/edit', methods=['GET', 'POST'])
def balance_assertions_edit(): def balance_assertions_edit():
assertion = db_session.get(BalanceAssertion, request.args['id']) assertion = db.session.get(BalanceAssertion, request.args['id'])
if not assertion: if not assertion:
abort(404) abort(404)
@ -156,6 +156,6 @@ def balance_assertions_edit():
assertion.account = request.form['account'] assertion.account = request.form['account']
assertion.quantity = quantity assertion.quantity = quantity
db_session.commit() db.session.commit()
return redirect('/balance-assertions') return redirect('/balance-assertions')

View File

@ -14,25 +14,22 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String from ..database import db
from sqlalchemy.orm import relationship
from ..database import Base, db_session
from ..models import Amount, Posting, Transaction from ..models import Amount, Posting, Transaction
class StatementLine(Base): class StatementLine(db.Model):
__tablename__ = 'statement_lines' __tablename__ = 'statement_lines'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
source_account = Column(String) source_account = db.Column(db.String)
dt = Column(DateTime) dt = db.Column(db.DateTime)
description = Column(String) description = db.Column(db.String)
quantity = Column(Integer) quantity = db.Column(db.Integer)
balance = Column(Integer) balance = db.Column(db.Integer)
commodity = Column(String) commodity = db.Column(db.String)
postings = relationship('StatementLinePosting', back_populates='statement_line') postings = db.relationship('StatementLinePosting', back_populates='statement_line')
def amount(self): def amount(self):
return Amount(self.quantity, self.commodity) return Amount(self.quantity, self.commodity)
@ -69,17 +66,17 @@ class StatementLine(Base):
for posting in self.postings: for posting in self.postings:
# TODO: Will be wonky if transaction covers multiple StatementLines # TODO: Will be wonky if transaction covers multiple StatementLines
db_session.delete(posting.transaction) db.session.delete(posting.transaction)
class StatementLineTransaction(Base, Transaction): class StatementLineTransaction(db.Model, Transaction):
__tablename__ = 'statement_line_transactions' __tablename__ = 'statement_line_transactions'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime) dt = db.Column(db.DateTime)
description = Column(String) description = db.Column(db.String)
postings = relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete') postings = db.relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete')
def charge_account(self, source_account): def charge_account(self, source_account):
if len(self.postings) > 2: if len(self.postings) > 2:
@ -89,20 +86,20 @@ class StatementLineTransaction(Base, Transaction):
if posting.account != source_account: if posting.account != source_account:
return posting.account return posting.account
class StatementLinePosting(Base, Posting): class StatementLinePosting(db.Model, Posting):
__tablename__ = 'statement_line_postings' __tablename__ = 'statement_line_postings'
id = Column(Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
transaction_id = Column(Integer, ForeignKey('statement_line_transactions.id')) transaction_id = db.Column(db.Integer, db.ForeignKey('statement_line_transactions.id'))
line_id = Column(Integer, ForeignKey('statement_lines.id')) line_id = db.Column(db.Integer, db.ForeignKey('statement_lines.id'))
description = Column(String) description = db.Column(db.String)
account = Column(String) account = db.Column(db.String)
quantity = Column(Integer) quantity = db.Column(db.Integer)
commodity = Column(String) commodity = db.Column(db.String)
transaction = relationship('StatementLineTransaction', back_populates='postings') transaction = db.relationship('StatementLineTransaction', back_populates='postings')
statement_line = relationship('StatementLine', back_populates='postings') statement_line = db.relationship('StatementLine', back_populates='postings')
def amount(self): def amount(self):
return Amount(self.quantity, self.commodity) return Amount(self.quantity, self.commodity)

View File

@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS from .. import AMOUNT_DPS
from ..database import db_session from ..database import db
from ..webapp import app from ..webapp import app
from .models import StatementLine, StatementLinePosting, StatementLineTransaction from .models import StatementLine, StatementLinePosting, StatementLineTransaction
@ -32,7 +32,7 @@ def statement_lines():
@app.route('/statement-lines/charge', methods=['POST']) @app.route('/statement-lines/charge', methods=['POST'])
def statement_line_charge(): def statement_line_charge():
statement_line = db_session.get(StatementLine, request.form['line-id']) statement_line = db.session.get(StatementLine, request.form['line-id'])
if not statement_line: if not statement_line:
abort(404) abort(404)
@ -48,14 +48,14 @@ def statement_line_charge():
] ]
) )
db_session.add(transaction) db.session.add(transaction)
db_session.commit() db.session.commit()
return 'OK' return 'OK'
@app.route('/statement-lines/edit-transaction', methods=['GET', 'POST']) @app.route('/statement-lines/edit-transaction', methods=['GET', 'POST'])
def statement_line_edit_transaction(): def statement_line_edit_transaction():
statement_line = db_session.get(StatementLine, request.args['line-id']) statement_line = db.session.get(StatementLine, request.args['line-id'])
if not statement_line: if not statement_line:
abort(404) abort(404)
@ -99,8 +99,8 @@ def statement_line_edit_transaction():
) )
transaction.assert_valid() transaction.assert_valid()
db_session.add(transaction) db.session.add(transaction)
db_session.commit() db.session.commit()
return redirect('/statement-lines') return redirect('/statement-lines')
@ -110,8 +110,8 @@ def statement_line_reconcile_transfer():
if len(line_ids) != 2: if len(line_ids) != 2:
raise Exception('Must select exactly 2 statement lines') raise Exception('Must select exactly 2 statement lines')
line1 = db_session.get(StatementLine, line_ids[0]) line1 = db.session.get(StatementLine, line_ids[0])
line2 = db_session.get(StatementLine, line_ids[1]) line2 = db.session.get(StatementLine, line_ids[1])
# Check same amount # Check same amount
if line1.quantity != -line2.quantity or line1.commodity != line2.commodity: if line1.quantity != -line2.quantity or line1.commodity != line2.commodity:
@ -129,7 +129,7 @@ def statement_line_reconcile_transfer():
StatementLinePosting(statement_line=line2, description=line2.description, account=line2.source_account, quantity=line2.quantity, commodity=line2.commodity) StatementLinePosting(statement_line=line2, description=line2.description, account=line2.source_account, quantity=line2.quantity, commodity=line2.commodity)
] ]
) )
db_session.add(transaction) db.session.add(transaction)
db_session.commit() db.session.commit()
return redirect('/statement-lines') return redirect('/statement-lines')

View File

@ -16,12 +16,16 @@
from flask import Flask from flask import Flask
from .database import db_session, init_db from .database import db
from .general_journal.models import GeneralJournalTransaction from .general_journal.models import GeneralJournalTransaction
from .statements.models import StatementLine, StatementLineTransaction from .statements.models import StatementLine, StatementLineTransaction
app = Flask(__name__) app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///drcr.db'
app.config['SQLALCHEMY_RECORD_QUERIES'] = True
db.init_app(app)
def all_transactions(): def all_transactions():
return ( return (
GeneralJournalTransaction.query.all() + GeneralJournalTransaction.query.all() +
@ -36,7 +40,7 @@ from .statements import views
@app.cli.command('initdb') @app.cli.command('initdb')
def initdb(): def initdb():
init_db() db.create_all()
@app.teardown_appcontext @app.teardown_appcontext
def shutdown_session(exception=None): def shutdown_session(exception=None):