From 72e08c01ff28dda722270503c6d6ad2bd701c068 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Mon, 27 Nov 2017 19:40:01 +1100 Subject: [PATCH] Abstract away DB implementation --- eos/base/tests.py | 5 ++--- eos/core/db/__init__.py | 22 +++++++++++++++++++++ eos/core/db/mongodb.py | 38 ++++++++++++++++++++++++++++++++++++ eos/core/objects/__init__.py | 20 ++++++++++--------- eos/psr/tests.py | 2 +- eosweb/core/main.py | 4 ++-- eosweb/core/settings.py | 3 ++- local_settings.example.py | 6 ++++++ 8 files changed, 84 insertions(+), 16 deletions(-) create mode 100644 eos/core/db/__init__.py create mode 100644 eos/core/db/mongodb.py diff --git a/eos/base/tests.py b/eos/base/tests.py index 2c6be96..45e2dff 100644 --- a/eos/base/tests.py +++ b/eos/base/tests.py @@ -24,7 +24,7 @@ class ElectionTestCase(EosTestCase): @classmethod def setUpClass(cls): db_connect('test') - dbinfo.client.drop_database('test') + dbinfo.provider.reset_db() def do_task_assert(self, election, task, next_task): self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY) @@ -66,8 +66,7 @@ class ElectionTestCase(EosTestCase): election.save() # Check that it saved - self.assertEqual(dbinfo.db[Election._db_name].find_one()['value'], election.serialise()) - self.assertEqual(EosObject.deserialise_and_unwrap(dbinfo.db[Election._db_name].find_one()).serialise(), election.serialise()) + self.assertEqual(Election.get_all()[0], election) self.assertEqualJSON(EosObject.deserialise_and_unwrap(EosObject.serialise_and_wrap(election)).serialise(), election.serialise()) diff --git a/eos/core/db/__init__.py b/eos/core/db/__init__.py new file mode 100644 index 0000000..61c5249 --- /dev/null +++ b/eos/core/db/__init__.py @@ -0,0 +1,22 @@ +# Eos - Verifiable elections +# Copyright © 2017 RunasSudo (Yingtong Li) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +db_providers = {} + +class DBProvider: + def __init__(self, db_name, db_uri): + self.db_name = db_name + self.db_uri = db_uri diff --git a/eos/core/db/mongodb.py b/eos/core/db/mongodb.py new file mode 100644 index 0000000..b25c2e3 --- /dev/null +++ b/eos/core/db/mongodb.py @@ -0,0 +1,38 @@ +# Eos - Verifiable elections +# Copyright © 2017 RunasSudo (Yingtong Li) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import pymongo + +import eos.core.db + +class MongoDBProvider(eos.core.db.DBProvider): + def connect(self): + self.client = pymongo.MongoClient(self.db_uri) + self.db = self.client[self.db_name] + + def get_all(self, collection): + return self.db[collection].find() + + def get_by_id(self, collection, _id): + return self.db[collection].find_one(_id) + + def update_by_id(self, collection, _id, value): + self.db[collection].replace_one({'_id': _id}, value, upsert=True) + + def reset_db(self): + self.client.drop_database(self.db_name) + +eos.core.db.db_providers['mongodb'] = MongoDBProvider diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py index 396c083..f847dec 100644 --- a/eos/core/objects/__init__.py +++ b/eos/core/objects/__init__.py @@ -27,7 +27,9 @@ except: if is_python: __pragma__('skip') - import pymongo + import eos.core.db + import eos.core.db.mongodb + from bson.binary import UUIDLegacy import base64 @@ -51,14 +53,13 @@ else: class DBInfo: def __init__(self): - self.client = None - self.db = None + self.provider = None dbinfo = DBInfo() -def db_connect(db_name, mongo_uri='mongodb://localhost:27017/'): - dbinfo.client = pymongo.MongoClient(mongo_uri) - dbinfo.db = dbinfo.client[db_name] +def db_connect(db_name, db_uri='mongodb://localhost:27017/', db_type='mongodb'): + dbinfo.provider = eos.core.db.db_providers[db_type](db_name, db_uri) + dbinfo.provider.connect() # Fields # ====== @@ -376,15 +377,16 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType): class TopLevelObject(DocumentObject): def save(self): #res = db[self._name].replace_one({'_id': self.serialise()['_id']}, self.serialise(), upsert=True) - res = dbinfo.db[self._db_name].replace_one({'_id': self._fields['_id'].serialise(self._id)}, EosObject.serialise_and_wrap(self), upsert=True) + #res = dbinfo.db[self._db_name].replace_one({'_id': self._fields['_id'].serialise(self._id)}, EosObject.serialise_and_wrap(self), upsert=True) + dbinfo.provider.update_by_id(self._db_name, self._fields['_id'].serialise(self._id), EosObject.serialise_and_wrap(self)) @classmethod def get_all(cls): - return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.db[cls._db_name].find()] + return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all(cls._db_name)] @classmethod def get_by_id(cls, _id): - return EosObject.deserialise_and_unwrap(dbinfo.db[cls._db_name].find_one(_id)) + return EosObject.deserialise_and_unwrap(dbinfo.provider.get_by_id(cls._db_name, _id)) class EmbeddedObject(DocumentObject): pass diff --git a/eos/psr/tests.py b/eos/psr/tests.py index 0718012..c0f161f 100644 --- a/eos/psr/tests.py +++ b/eos/psr/tests.py @@ -213,7 +213,7 @@ class ElectionTestCase(EosTestCase): @classmethod def setUpClass(cls): db_connect('test') - dbinfo.client.drop_database('test') + dbinfo.provider.reset_db() def do_task_assert(self, election, task, next_task): self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY) diff --git a/eosweb/core/main.py b/eosweb/core/main.py index 05ade68..16a4e08 100644 --- a/eosweb/core/main.py +++ b/eosweb/core/main.py @@ -48,7 +48,7 @@ if 'EOSWEB_SETTINGS' in os.environ: app.config.from_envvar('EOSWEB_SETTINGS') # Connect to database -db_connect(app.config['DB_NAME'], app.config['MONGO_URI']) +db_connect(app.config['DB_NAME'], app.config['DB_URI'], app.config['DB_TYPE']) # Make Flask's serialisation, e.g. for sessions, EosObject aware class EosObjectJSONEncoder(flask.json.JSONEncoder): @@ -83,7 +83,7 @@ def run_tests(prefix, lang): @app.cli.command('drop_db_and_setup') def setup_test_election(): # DANGER! - dbinfo.client.drop_database(app.config['DB_NAME']) + dbinfo.provider.reset_db() # Set up election election = PSRElection() diff --git a/eosweb/core/settings.py b/eosweb/core/settings.py index e3f6113..519dd38 100644 --- a/eosweb/core/settings.py +++ b/eosweb/core/settings.py @@ -18,7 +18,8 @@ ORG_NAME = 'FIXME' BASE_URI = 'http://localhost:5000' -MONGO_URI = 'mongodb://localhost:27017/' +DB_TYPE = 'mongodb' +DB_URI = 'mongodb://localhost:27017/' DB_NAME = 'eos' SECRET_KEY = 'FIXME' diff --git a/local_settings.example.py b/local_settings.example.py index 7f98033..3405a85 100644 --- a/local_settings.example.py +++ b/local_settings.example.py @@ -7,6 +7,12 @@ AUTH_METHODS = [ ('reddit', 'Reddit') ] +# MongoDB + +DB_TYPE = 'mongodb' +DB_URI = 'mongodb://localhost:27017/' +DB_NAME = 'eos' + # Email SMTP_HOST, SMTP_PORT = 'localhost', 25