diff --git a/eos/base/tests.py b/eos/base/tests.py
index 2c6be96..45e2dff 100644
--- a/eos/base/tests.py
+++ b/eos/base/tests.py
@@ -24,7 +24,7 @@ class ElectionTestCase(EosTestCase):
@classmethod
def setUpClass(cls):
db_connect('test')
- dbinfo.client.drop_database('test')
+ dbinfo.provider.reset_db()
def do_task_assert(self, election, task, next_task):
self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY)
@@ -66,8 +66,7 @@ class ElectionTestCase(EosTestCase):
election.save()
# Check that it saved
- self.assertEqual(dbinfo.db[Election._db_name].find_one()['value'], election.serialise())
- self.assertEqual(EosObject.deserialise_and_unwrap(dbinfo.db[Election._db_name].find_one()).serialise(), election.serialise())
+ self.assertEqual(Election.get_all()[0], election)
self.assertEqualJSON(EosObject.deserialise_and_unwrap(EosObject.serialise_and_wrap(election)).serialise(), election.serialise())
diff --git a/eos/core/db/__init__.py b/eos/core/db/__init__.py
new file mode 100644
index 0000000..61c5249
--- /dev/null
+++ b/eos/core/db/__init__.py
@@ -0,0 +1,22 @@
+# Eos - Verifiable elections
+# Copyright © 2017 RunasSudo (Yingtong Li)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+db_providers = {}
+
+class DBProvider:
+ def __init__(self, db_name, db_uri):
+ self.db_name = db_name
+ self.db_uri = db_uri
diff --git a/eos/core/db/mongodb.py b/eos/core/db/mongodb.py
new file mode 100644
index 0000000..b25c2e3
--- /dev/null
+++ b/eos/core/db/mongodb.py
@@ -0,0 +1,38 @@
+# Eos - Verifiable elections
+# Copyright © 2017 RunasSudo (Yingtong Li)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import pymongo
+
+import eos.core.db
+
+class MongoDBProvider(eos.core.db.DBProvider):
+ def connect(self):
+ self.client = pymongo.MongoClient(self.db_uri)
+ self.db = self.client[self.db_name]
+
+ def get_all(self, collection):
+ return self.db[collection].find()
+
+ def get_by_id(self, collection, _id):
+ return self.db[collection].find_one(_id)
+
+ def update_by_id(self, collection, _id, value):
+ self.db[collection].replace_one({'_id': _id}, value, upsert=True)
+
+ def reset_db(self):
+ self.client.drop_database(self.db_name)
+
+eos.core.db.db_providers['mongodb'] = MongoDBProvider
diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py
index 396c083..f847dec 100644
--- a/eos/core/objects/__init__.py
+++ b/eos/core/objects/__init__.py
@@ -27,7 +27,9 @@ except:
if is_python:
__pragma__('skip')
- import pymongo
+ import eos.core.db
+ import eos.core.db.mongodb
+
from bson.binary import UUIDLegacy
import base64
@@ -51,14 +53,13 @@ else:
class DBInfo:
def __init__(self):
- self.client = None
- self.db = None
+ self.provider = None
dbinfo = DBInfo()
-def db_connect(db_name, mongo_uri='mongodb://localhost:27017/'):
- dbinfo.client = pymongo.MongoClient(mongo_uri)
- dbinfo.db = dbinfo.client[db_name]
+def db_connect(db_name, db_uri='mongodb://localhost:27017/', db_type='mongodb'):
+ dbinfo.provider = eos.core.db.db_providers[db_type](db_name, db_uri)
+ dbinfo.provider.connect()
# Fields
# ======
@@ -376,15 +377,16 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType):
class TopLevelObject(DocumentObject):
def save(self):
#res = db[self._name].replace_one({'_id': self.serialise()['_id']}, self.serialise(), upsert=True)
- res = dbinfo.db[self._db_name].replace_one({'_id': self._fields['_id'].serialise(self._id)}, EosObject.serialise_and_wrap(self), upsert=True)
+ #res = dbinfo.db[self._db_name].replace_one({'_id': self._fields['_id'].serialise(self._id)}, EosObject.serialise_and_wrap(self), upsert=True)
+ dbinfo.provider.update_by_id(self._db_name, self._fields['_id'].serialise(self._id), EosObject.serialise_and_wrap(self))
@classmethod
def get_all(cls):
- return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.db[cls._db_name].find()]
+ return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all(cls._db_name)]
@classmethod
def get_by_id(cls, _id):
- return EosObject.deserialise_and_unwrap(dbinfo.db[cls._db_name].find_one(_id))
+ return EosObject.deserialise_and_unwrap(dbinfo.provider.get_by_id(cls._db_name, _id))
class EmbeddedObject(DocumentObject):
pass
diff --git a/eos/psr/tests.py b/eos/psr/tests.py
index 0718012..c0f161f 100644
--- a/eos/psr/tests.py
+++ b/eos/psr/tests.py
@@ -213,7 +213,7 @@ class ElectionTestCase(EosTestCase):
@classmethod
def setUpClass(cls):
db_connect('test')
- dbinfo.client.drop_database('test')
+ dbinfo.provider.reset_db()
def do_task_assert(self, election, task, next_task):
self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY)
diff --git a/eosweb/core/main.py b/eosweb/core/main.py
index 05ade68..16a4e08 100644
--- a/eosweb/core/main.py
+++ b/eosweb/core/main.py
@@ -48,7 +48,7 @@ if 'EOSWEB_SETTINGS' in os.environ:
app.config.from_envvar('EOSWEB_SETTINGS')
# Connect to database
-db_connect(app.config['DB_NAME'], app.config['MONGO_URI'])
+db_connect(app.config['DB_NAME'], app.config['DB_URI'], app.config['DB_TYPE'])
# Make Flask's serialisation, e.g. for sessions, EosObject aware
class EosObjectJSONEncoder(flask.json.JSONEncoder):
@@ -83,7 +83,7 @@ def run_tests(prefix, lang):
@app.cli.command('drop_db_and_setup')
def setup_test_election():
# DANGER!
- dbinfo.client.drop_database(app.config['DB_NAME'])
+ dbinfo.provider.reset_db()
# Set up election
election = PSRElection()
diff --git a/eosweb/core/settings.py b/eosweb/core/settings.py
index e3f6113..519dd38 100644
--- a/eosweb/core/settings.py
+++ b/eosweb/core/settings.py
@@ -18,7 +18,8 @@ ORG_NAME = 'FIXME'
BASE_URI = 'http://localhost:5000'
-MONGO_URI = 'mongodb://localhost:27017/'
+DB_TYPE = 'mongodb'
+DB_URI = 'mongodb://localhost:27017/'
DB_NAME = 'eos'
SECRET_KEY = 'FIXME'
diff --git a/local_settings.example.py b/local_settings.example.py
index 7f98033..3405a85 100644
--- a/local_settings.example.py
+++ b/local_settings.example.py
@@ -7,6 +7,12 @@ AUTH_METHODS = [
('reddit', 'Reddit')
]
+# MongoDB
+
+DB_TYPE = 'mongodb'
+DB_URI = 'mongodb://localhost:27017/'
+DB_NAME = 'eos'
+
# Email
SMTP_HOST, SMTP_PORT = 'localhost', 25