Port from raw SQLAlchemy to Flask-SQLAlchemy

This commit is contained in:
RunasSudo 2023-01-02 16:59:12 +11:00
parent 3e6e1f18ed
commit d8caedd2dd
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
7 changed files with 76 additions and 91 deletions

4
.gitignore vendored
View File

@ -1,6 +1,4 @@
__pycache__
/drcr/config.py
/mydata
/scripts
/instance
/venv
*.db

View File

@ -14,17 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from flask_sqlalchemy import SQLAlchemy
engine = create_engine('sqlite:///drcr.db')
db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))
Base = declarative_base()
Base.query = db_session.query_property()
def init_db():
from .general_journal import models
from .statements import models
Base.metadata.create_all(bind=engine)
db = SQLAlchemy()

View File

@ -14,48 +14,45 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from ..database import Base
from ..database import db
from ..models import Amount, Posting, Transaction
class GeneralJournalTransaction(Base, Transaction):
class GeneralJournalTransaction(db.Model, Transaction):
__tablename__ = 'general_journal_transactions'
id = Column(Integer, primary_key=True)
id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime)
description = Column(String)
dt = db.Column(db.DateTime)
description = db.Column(db.String)
postings = relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan')
postings = db.relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan')
class GeneralJournalPosting(Base, Posting):
class GeneralJournalPosting(db.Model, Posting):
__tablename__ = 'general_journal_postings'
id = Column(Integer, primary_key=True)
transaction_id = Column(Integer, ForeignKey('general_journal_transactions.id'))
id = db.Column(db.Integer, primary_key=True)
transaction_id = db.Column(db.Integer, db.ForeignKey('general_journal_transactions.id'))
description = Column(String)
account = Column(String)
quantity = Column(Integer)
commodity = Column(String)
description = db.Column(db.String)
account = db.Column(db.String)
quantity = db.Column(db.Integer)
commodity = db.Column(db.String)
transaction = relationship('GeneralJournalTransaction', back_populates='postings')
transaction = db.relationship('GeneralJournalTransaction', back_populates='postings')
def amount(self):
return Amount(self.quantity, self.commodity)
class BalanceAssertion(Base):
class BalanceAssertion(db.Model):
__tablename__ = 'balance_assertions'
id = Column(Integer, primary_key=True)
id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime)
description = Column(String)
account = Column(String)
quantity = Column(Integer)
commodity = Column(String)
dt = db.Column(db.DateTime)
description = db.Column(db.String)
account = db.Column(db.String)
quantity = db.Column(db.Integer)
commodity = db.Column(db.String)
def balance(self):
return Amount(self.quantity, self.commodity)

View File

@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS
from ..database import db_session
from ..database import db
from ..models import TrialBalancer
from ..webapp import all_transactions, app
from .models import Amount, BalanceAssertion, GeneralJournalPosting, GeneralJournalTransaction
@ -58,14 +58,14 @@ def general_journal_new():
transaction.assert_valid()
db_session.add(transaction)
db_session.commit()
db.session.add(transaction)
db.session.commit()
return redirect('/general-journal')
@app.route('/general-journal/edit', methods=['GET', 'POST'])
def general_journal_edit():
transaction = db_session.get(GeneralJournalTransaction, request.args['id'])
transaction = db.session.get(GeneralJournalTransaction, request.args['id'])
if not transaction:
abort(404)
@ -91,7 +91,7 @@ def general_journal_edit():
transaction.assert_valid()
db_session.commit()
db.session.commit()
return redirect('/general-journal')
@ -132,14 +132,14 @@ def balance_assertions_new():
quantity=quantity,
commodity='$'
)
db_session.add(assertion)
db_session.commit()
db.session.add(assertion)
db.session.commit()
return redirect('/balance-assertions')
@app.route('/balance-assertions/edit', methods=['GET', 'POST'])
def balance_assertions_edit():
assertion = db_session.get(BalanceAssertion, request.args['id'])
assertion = db.session.get(BalanceAssertion, request.args['id'])
if not assertion:
abort(404)
@ -156,6 +156,6 @@ def balance_assertions_edit():
assertion.account = request.form['account']
assertion.quantity = quantity
db_session.commit()
db.session.commit()
return redirect('/balance-assertions')

View File

@ -14,25 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from ..database import Base, db_session
from ..database import db
from ..models import Amount, Posting, Transaction
class StatementLine(Base):
class StatementLine(db.Model):
__tablename__ = 'statement_lines'
id = Column(Integer, primary_key=True)
id = db.Column(db.Integer, primary_key=True)
source_account = Column(String)
dt = Column(DateTime)
description = Column(String)
quantity = Column(Integer)
balance = Column(Integer)
commodity = Column(String)
source_account = db.Column(db.String)
dt = db.Column(db.DateTime)
description = db.Column(db.String)
quantity = db.Column(db.Integer)
balance = db.Column(db.Integer)
commodity = db.Column(db.String)
postings = relationship('StatementLinePosting', back_populates='statement_line')
postings = db.relationship('StatementLinePosting', back_populates='statement_line')
def amount(self):
return Amount(self.quantity, self.commodity)
@ -69,17 +66,17 @@ class StatementLine(Base):
for posting in self.postings:
# TODO: Will be wonky if transaction covers multiple StatementLines
db_session.delete(posting.transaction)
db.session.delete(posting.transaction)
class StatementLineTransaction(Base, Transaction):
class StatementLineTransaction(db.Model, Transaction):
__tablename__ = 'statement_line_transactions'
id = Column(Integer, primary_key=True)
id = db.Column(db.Integer, primary_key=True)
dt = Column(DateTime)
description = Column(String)
dt = db.Column(db.DateTime)
description = db.Column(db.String)
postings = relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete')
postings = db.relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete')
def charge_account(self, source_account):
if len(self.postings) > 2:
@ -89,20 +86,20 @@ class StatementLineTransaction(Base, Transaction):
if posting.account != source_account:
return posting.account
class StatementLinePosting(Base, Posting):
class StatementLinePosting(db.Model, Posting):
__tablename__ = 'statement_line_postings'
id = Column(Integer, primary_key=True)
transaction_id = Column(Integer, ForeignKey('statement_line_transactions.id'))
line_id = Column(Integer, ForeignKey('statement_lines.id'))
id = db.Column(db.Integer, primary_key=True)
transaction_id = db.Column(db.Integer, db.ForeignKey('statement_line_transactions.id'))
line_id = db.Column(db.Integer, db.ForeignKey('statement_lines.id'))
description = Column(String)
account = Column(String)
quantity = Column(Integer)
commodity = Column(String)
description = db.Column(db.String)
account = db.Column(db.String)
quantity = db.Column(db.Integer)
commodity = db.Column(db.String)
transaction = relationship('StatementLineTransaction', back_populates='postings')
statement_line = relationship('StatementLine', back_populates='postings')
transaction = db.relationship('StatementLineTransaction', back_populates='postings')
statement_line = db.relationship('StatementLine', back_populates='postings')
def amount(self):
return Amount(self.quantity, self.commodity)

View File

@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS
from ..database import db_session
from ..database import db
from ..webapp import app
from .models import StatementLine, StatementLinePosting, StatementLineTransaction
@ -32,7 +32,7 @@ def statement_lines():
@app.route('/statement-lines/charge', methods=['POST'])
def statement_line_charge():
statement_line = db_session.get(StatementLine, request.form['line-id'])
statement_line = db.session.get(StatementLine, request.form['line-id'])
if not statement_line:
abort(404)
@ -48,14 +48,14 @@ def statement_line_charge():
]
)
db_session.add(transaction)
db_session.commit()
db.session.add(transaction)
db.session.commit()
return 'OK'
@app.route('/statement-lines/edit-transaction', methods=['GET', 'POST'])
def statement_line_edit_transaction():
statement_line = db_session.get(StatementLine, request.args['line-id'])
statement_line = db.session.get(StatementLine, request.args['line-id'])
if not statement_line:
abort(404)
@ -99,8 +99,8 @@ def statement_line_edit_transaction():
)
transaction.assert_valid()
db_session.add(transaction)
db_session.commit()
db.session.add(transaction)
db.session.commit()
return redirect('/statement-lines')
@ -110,8 +110,8 @@ def statement_line_reconcile_transfer():
if len(line_ids) != 2:
raise Exception('Must select exactly 2 statement lines')
line1 = db_session.get(StatementLine, line_ids[0])
line2 = db_session.get(StatementLine, line_ids[1])
line1 = db.session.get(StatementLine, line_ids[0])
line2 = db.session.get(StatementLine, line_ids[1])
# Check same amount
if line1.quantity != -line2.quantity or line1.commodity != line2.commodity:
@ -129,7 +129,7 @@ def statement_line_reconcile_transfer():
StatementLinePosting(statement_line=line2, description=line2.description, account=line2.source_account, quantity=line2.quantity, commodity=line2.commodity)
]
)
db_session.add(transaction)
db_session.commit()
db.session.add(transaction)
db.session.commit()
return redirect('/statement-lines')

View File

@ -16,12 +16,16 @@
from flask import Flask
from .database import db_session, init_db
from .database import db
from .general_journal.models import GeneralJournalTransaction
from .statements.models import StatementLine, StatementLineTransaction
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///drcr.db'
app.config['SQLALCHEMY_RECORD_QUERIES'] = True
db.init_app(app)
def all_transactions():
return (
GeneralJournalTransaction.query.all() +
@ -36,7 +40,7 @@ from .statements import views
@app.cli.command('initdb')
def initdb():
init_db()
db.create_all()
@app.teardown_appcontext
def shutdown_session(exception=None):