diff --git a/.gitignore b/.gitignore
index 84920f9..975657e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,4 @@
__pycache__
/drcr/config.py
-/mydata
-/scripts
+/instance
/venv
-*.db
diff --git a/drcr/database.py b/drcr/database.py
index 251ab4c..60c844a 100644
--- a/drcr/database.py
+++ b/drcr/database.py
@@ -14,17 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from sqlalchemy import create_engine
-from sqlalchemy.orm import scoped_session, sessionmaker
-from sqlalchemy.ext.declarative import declarative_base
+from flask_sqlalchemy import SQLAlchemy
-engine = create_engine('sqlite:///drcr.db')
-db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))
-Base = declarative_base()
-Base.query = db_session.query_property()
-
-def init_db():
- from .general_journal import models
- from .statements import models
-
- Base.metadata.create_all(bind=engine)
+db = SQLAlchemy()
diff --git a/drcr/general_journal/models.py b/drcr/general_journal/models.py
index 359216f..33fb528 100644
--- a/drcr/general_journal/models.py
+++ b/drcr/general_journal/models.py
@@ -14,48 +14,45 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
-from sqlalchemy.orm import relationship
-
-from ..database import Base
+from ..database import db
from ..models import Amount, Posting, Transaction
-class GeneralJournalTransaction(Base, Transaction):
+class GeneralJournalTransaction(db.Model, Transaction):
__tablename__ = 'general_journal_transactions'
- id = Column(Integer, primary_key=True)
+ id = db.Column(db.Integer, primary_key=True)
- dt = Column(DateTime)
- description = Column(String)
+ dt = db.Column(db.DateTime)
+ description = db.Column(db.String)
- postings = relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan')
+ postings = db.relationship('GeneralJournalPosting', back_populates='transaction', cascade='all, delete-orphan')
-class GeneralJournalPosting(Base, Posting):
+class GeneralJournalPosting(db.Model, Posting):
__tablename__ = 'general_journal_postings'
- id = Column(Integer, primary_key=True)
- transaction_id = Column(Integer, ForeignKey('general_journal_transactions.id'))
+ id = db.Column(db.Integer, primary_key=True)
+ transaction_id = db.Column(db.Integer, db.ForeignKey('general_journal_transactions.id'))
- description = Column(String)
- account = Column(String)
- quantity = Column(Integer)
- commodity = Column(String)
+ description = db.Column(db.String)
+ account = db.Column(db.String)
+ quantity = db.Column(db.Integer)
+ commodity = db.Column(db.String)
- transaction = relationship('GeneralJournalTransaction', back_populates='postings')
+ transaction = db.relationship('GeneralJournalTransaction', back_populates='postings')
def amount(self):
return Amount(self.quantity, self.commodity)
-class BalanceAssertion(Base):
+class BalanceAssertion(db.Model):
__tablename__ = 'balance_assertions'
- id = Column(Integer, primary_key=True)
+ id = db.Column(db.Integer, primary_key=True)
- dt = Column(DateTime)
- description = Column(String)
- account = Column(String)
- quantity = Column(Integer)
- commodity = Column(String)
+ dt = db.Column(db.DateTime)
+ description = db.Column(db.String)
+ account = db.Column(db.String)
+ quantity = db.Column(db.Integer)
+ commodity = db.Column(db.String)
def balance(self):
return Amount(self.quantity, self.commodity)
diff --git a/drcr/general_journal/views.py b/drcr/general_journal/views.py
index 91f340c..748f1ed 100644
--- a/drcr/general_journal/views.py
+++ b/drcr/general_journal/views.py
@@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS
-from ..database import db_session
+from ..database import db
from ..models import TrialBalancer
from ..webapp import all_transactions, app
from .models import Amount, BalanceAssertion, GeneralJournalPosting, GeneralJournalTransaction
@@ -58,14 +58,14 @@ def general_journal_new():
transaction.assert_valid()
- db_session.add(transaction)
- db_session.commit()
+ db.session.add(transaction)
+ db.session.commit()
return redirect('/general-journal')
@app.route('/general-journal/edit', methods=['GET', 'POST'])
def general_journal_edit():
- transaction = db_session.get(GeneralJournalTransaction, request.args['id'])
+ transaction = db.session.get(GeneralJournalTransaction, request.args['id'])
if not transaction:
abort(404)
@@ -91,7 +91,7 @@ def general_journal_edit():
transaction.assert_valid()
- db_session.commit()
+ db.session.commit()
return redirect('/general-journal')
@@ -132,14 +132,14 @@ def balance_assertions_new():
quantity=quantity,
commodity='$'
)
- db_session.add(assertion)
- db_session.commit()
+ db.session.add(assertion)
+ db.session.commit()
return redirect('/balance-assertions')
@app.route('/balance-assertions/edit', methods=['GET', 'POST'])
def balance_assertions_edit():
- assertion = db_session.get(BalanceAssertion, request.args['id'])
+ assertion = db.session.get(BalanceAssertion, request.args['id'])
if not assertion:
abort(404)
@@ -156,6 +156,6 @@ def balance_assertions_edit():
assertion.account = request.form['account']
assertion.quantity = quantity
- db_session.commit()
+ db.session.commit()
return redirect('/balance-assertions')
diff --git a/drcr/statements/models.py b/drcr/statements/models.py
index 1d6b61d..951865b 100644
--- a/drcr/statements/models.py
+++ b/drcr/statements/models.py
@@ -14,25 +14,22 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
-from sqlalchemy.orm import relationship
-
-from ..database import Base, db_session
+from ..database import db
from ..models import Amount, Posting, Transaction
-class StatementLine(Base):
+class StatementLine(db.Model):
__tablename__ = 'statement_lines'
- id = Column(Integer, primary_key=True)
+ id = db.Column(db.Integer, primary_key=True)
- source_account = Column(String)
- dt = Column(DateTime)
- description = Column(String)
- quantity = Column(Integer)
- balance = Column(Integer)
- commodity = Column(String)
+ source_account = db.Column(db.String)
+ dt = db.Column(db.DateTime)
+ description = db.Column(db.String)
+ quantity = db.Column(db.Integer)
+ balance = db.Column(db.Integer)
+ commodity = db.Column(db.String)
- postings = relationship('StatementLinePosting', back_populates='statement_line')
+ postings = db.relationship('StatementLinePosting', back_populates='statement_line')
def amount(self):
return Amount(self.quantity, self.commodity)
@@ -69,17 +66,17 @@ class StatementLine(Base):
for posting in self.postings:
# TODO: Will be wonky if transaction covers multiple StatementLines
- db_session.delete(posting.transaction)
+ db.session.delete(posting.transaction)
-class StatementLineTransaction(Base, Transaction):
+class StatementLineTransaction(db.Model, Transaction):
__tablename__ = 'statement_line_transactions'
- id = Column(Integer, primary_key=True)
+ id = db.Column(db.Integer, primary_key=True)
- dt = Column(DateTime)
- description = Column(String)
+ dt = db.Column(db.DateTime)
+ description = db.Column(db.String)
- postings = relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete')
+ postings = db.relationship('StatementLinePosting', back_populates='transaction', cascade='all, delete')
def charge_account(self, source_account):
if len(self.postings) > 2:
@@ -89,20 +86,20 @@ class StatementLineTransaction(Base, Transaction):
if posting.account != source_account:
return posting.account
-class StatementLinePosting(Base, Posting):
+class StatementLinePosting(db.Model, Posting):
__tablename__ = 'statement_line_postings'
- id = Column(Integer, primary_key=True)
- transaction_id = Column(Integer, ForeignKey('statement_line_transactions.id'))
- line_id = Column(Integer, ForeignKey('statement_lines.id'))
+ id = db.Column(db.Integer, primary_key=True)
+ transaction_id = db.Column(db.Integer, db.ForeignKey('statement_line_transactions.id'))
+ line_id = db.Column(db.Integer, db.ForeignKey('statement_lines.id'))
- description = Column(String)
- account = Column(String)
- quantity = Column(Integer)
- commodity = Column(String)
+ description = db.Column(db.String)
+ account = db.Column(db.String)
+ quantity = db.Column(db.Integer)
+ commodity = db.Column(db.String)
- transaction = relationship('StatementLineTransaction', back_populates='postings')
- statement_line = relationship('StatementLine', back_populates='postings')
+ transaction = db.relationship('StatementLineTransaction', back_populates='postings')
+ statement_line = db.relationship('StatementLine', back_populates='postings')
def amount(self):
return Amount(self.quantity, self.commodity)
diff --git a/drcr/statements/views.py b/drcr/statements/views.py
index 50d1d91..8b3b288 100644
--- a/drcr/statements/views.py
+++ b/drcr/statements/views.py
@@ -17,7 +17,7 @@
from flask import abort, redirect, render_template, request
from .. import AMOUNT_DPS
-from ..database import db_session
+from ..database import db
from ..webapp import app
from .models import StatementLine, StatementLinePosting, StatementLineTransaction
@@ -32,7 +32,7 @@ def statement_lines():
@app.route('/statement-lines/charge', methods=['POST'])
def statement_line_charge():
- statement_line = db_session.get(StatementLine, request.form['line-id'])
+ statement_line = db.session.get(StatementLine, request.form['line-id'])
if not statement_line:
abort(404)
@@ -48,14 +48,14 @@ def statement_line_charge():
]
)
- db_session.add(transaction)
- db_session.commit()
+ db.session.add(transaction)
+ db.session.commit()
return 'OK'
@app.route('/statement-lines/edit-transaction', methods=['GET', 'POST'])
def statement_line_edit_transaction():
- statement_line = db_session.get(StatementLine, request.args['line-id'])
+ statement_line = db.session.get(StatementLine, request.args['line-id'])
if not statement_line:
abort(404)
@@ -99,8 +99,8 @@ def statement_line_edit_transaction():
)
transaction.assert_valid()
- db_session.add(transaction)
- db_session.commit()
+ db.session.add(transaction)
+ db.session.commit()
return redirect('/statement-lines')
@@ -110,8 +110,8 @@ def statement_line_reconcile_transfer():
if len(line_ids) != 2:
raise Exception('Must select exactly 2 statement lines')
- line1 = db_session.get(StatementLine, line_ids[0])
- line2 = db_session.get(StatementLine, line_ids[1])
+ line1 = db.session.get(StatementLine, line_ids[0])
+ line2 = db.session.get(StatementLine, line_ids[1])
# Check same amount
if line1.quantity != -line2.quantity or line1.commodity != line2.commodity:
@@ -129,7 +129,7 @@ def statement_line_reconcile_transfer():
StatementLinePosting(statement_line=line2, description=line2.description, account=line2.source_account, quantity=line2.quantity, commodity=line2.commodity)
]
)
- db_session.add(transaction)
- db_session.commit()
+ db.session.add(transaction)
+ db.session.commit()
return redirect('/statement-lines')
diff --git a/drcr/webapp.py b/drcr/webapp.py
index 70d22ae..91fe586 100644
--- a/drcr/webapp.py
+++ b/drcr/webapp.py
@@ -16,12 +16,16 @@
from flask import Flask
-from .database import db_session, init_db
+from .database import db
from .general_journal.models import GeneralJournalTransaction
from .statements.models import StatementLine, StatementLineTransaction
app = Flask(__name__)
+app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///drcr.db'
+app.config['SQLALCHEMY_RECORD_QUERIES'] = True
+db.init_app(app)
+
def all_transactions():
return (
GeneralJournalTransaction.query.all() +
@@ -36,7 +40,7 @@ from .statements import views
@app.cli.command('initdb')
def initdb():
- init_db()
+ db.create_all()
@app.teardown_appcontext
def shutdown_session(exception=None):