diff --git a/eos/base/election.py b/eos/base/election.py
index fa520bd..cea1e09 100644
--- a/eos/base/election.py
+++ b/eos/base/election.py
@@ -25,11 +25,11 @@ class PlaintextBallotQuestion(BallotQuestion):
class Ballot(EmbeddedObject):
_id = UUIDField()
- questions = EmbeddedObjectListField(BallotQuestion)
+ questions = EmbeddedObjectListField()
class Voter(EmbeddedObject):
_id = UUIDField()
- ballots = EmbeddedObjectListField(Ballot)
+ ballots = EmbeddedObjectListField()
class Question(EmbeddedObject):
prompt = StringField()
@@ -39,7 +39,7 @@ class ApprovalQuestion(Question):
class Election(TopLevelObject):
_id = UUIDField()
- workflow = EmbeddedObjectField(Workflow)
+ workflow = EmbeddedObjectField(Workflow) # Once saved, we don't care what kind of workflow it is
name = StringField()
- voters = EmbeddedObjectListField(Voter, hashed=False)
- questions = EmbeddedObjectListField(Question)
+ voters = EmbeddedObjectListField(hashed=False)
+ questions = EmbeddedObjectListField()
diff --git a/eos/base/tests.py b/eos/base/tests.py
index 6d443e6..c6ca686 100644
--- a/eos/base/tests.py
+++ b/eos/base/tests.py
@@ -50,15 +50,10 @@ class ElectionTestCase(TestCase):
question.choices.append('Andrew Citizen')
election.questions.append(question)
- #election.save()
+ election.save()
# Check that it saved
- #self.assertEqual(Election.objects.get(0)._id, election._id) # TODO: Compare JSON
-
- # Retrieve from scratch, too
- #self.db.close()
- #self.db = mongoengine.connect('test')
- #self.assertEqual(Election.objects.get(0)._id, election._id)
+ self.assertEqual(db[Election._name].find_one(), election.serialise())
# Freeze election
self.assertEqual(election.workflow.get_task('eos.base.workflow.TaskConfigureElection').status, WorkflowTask.Status.READY)
diff --git a/eos/base/workflow.py b/eos/base/workflow.py
index d89cea2..fd34e19 100644
--- a/eos/base/workflow.py
+++ b/eos/base/workflow.py
@@ -74,7 +74,7 @@ class WorkflowTask(EmbeddedObject):
self.fire_event('exit')
class Workflow(EmbeddedObject):
- tasks = EmbeddedObjectListField(WorkflowTask)
+ tasks = EmbeddedObjectListField()
meta = {
'abstract': True
}
diff --git a/eos/core/objects/python.py b/eos/core/objects/python.py
index 49cf97c..34909d6 100644
--- a/eos/core/objects/python.py
+++ b/eos/core/objects/python.py
@@ -15,6 +15,7 @@
# along with this program. If not, see .
import pymongo
+from bson.binary import UUIDLegacy
import uuid
@@ -50,7 +51,7 @@ class EmbeddedObjectField(Field):
self.object_type = object_type
def serialise(self, value):
- return value.serialise_and_wrap(self.object_type)
+ return EosObject.serialise_and_wrap(value, self.object_type)
def deserialise(self, value):
return EosObject.deserialise_and_unwrap(value, self.object_type)
@@ -66,28 +67,59 @@ class ListField(Field):
def deserialise(self, value):
return [self.element_field.deserialise(x) for x in value]
-EmbeddedObjectListField = ListField
+class EmbeddedObjectListField(Field):
+ def __init__(self, object_type=None, *args, **kwargs):
+ super().__init__(default=[], *args, **kwargs)
+ self.object_type = object_type
+
+ def serialise(self, value):
+ return [EosObject.serialise_and_wrap(x, self.object_type) for x in value]
+
+ def deserialise(self, value):
+ return [EosObject.deserialise_and_unwrap(x, self.object_type) for x in value]
class UUIDField(Field):
def __init__(self, *args, **kwargs):
super().__init__(default=uuid.uuid4, *args, **kwargs)
def serialise(self, value):
- return str(uuid.uuid4)
+ return str(value)
def unserialise(self, value):
- return uuid.uuid4(value)
+ return uuid.UUID(value)
# Objects
# =======
class EosObjectType(type):
def __new__(meta, name, bases, attrs):
- #meta, name, bases, attrs = meta.before_new(meta, name, bases, attrs)
cls = type.__new__(meta, name, bases, attrs)
+ cls._name = cls.__module__ + '.' + cls.__qualname__
+ if name != 'EosObject':
+ EosObject.objects[cls._name] = cls
+ return cls
+
+class EosObject(metaclass=EosObjectType):
+ objects = {}
+
+ @staticmethod
+ def serialise_and_wrap(value, object_type=None):
+ if object_type:
+ return value.serialise()
+ return {'type': value._name, 'value': value.serialise()}
+
+ @staticmethod
+ def deserialise_and_unwrap(value, object_type=None):
+ if object_type:
+ return object_type.deserialise(value)
+ return EosObject.objects[value['type']].deserialise(value['value'])
+
+class DocumentObjectType(EosObjectType):
+ def __new__(meta, name, bases, attrs):
+ cls = EosObjectType.__new__(meta, name, bases, attrs)
# Process fields
- fields = cls._fields if hasattr(cls, '_fields') else {}
+ fields = cls._fields.copy() if hasattr(cls, '_fields') else {} # remember to .copy() XD
for attr in list(dir(cls)):
val = getattr(cls, attr)
if isinstance(val, Field):
@@ -95,14 +127,29 @@ class EosObjectType(type):
delattr(cls, attr)
cls._fields = fields
- cls._name = cls.__module__ + '.' + cls.__qualname__
-
return cls
-class EosObject(metaclass=EosObjectType):
+class DocumentObject(metaclass=DocumentObjectType):
def __init__(self, *args, **kwargs):
for attr, val in self._fields.items():
- setattr(self, attr, kwargs.get(attr, val.default))
+ if attr in kwargs:
+ setattr(self, attr, kwargs[attr])
+ else:
+ default = val.default
+ if callable(default):
+ default = default()
+ setattr(self, attr, default)
+
+ def serialise(self):
+ return {attr: val.serialise(getattr(self, attr)) for attr, val in self._fields.items()}
+
+ @classmethod
+ def deserialise(cls, value):
+ return cls(**value) # wew
-TopLevelObject = EosObject
-EmbeddedObject = EosObject
+class TopLevelObject(DocumentObject):
+ def save(self):
+ res = db[self._name].replace_one({'_id': self.serialise()['_id']}, self.serialise(), upsert=True)
+
+class EmbeddedObject(DocumentObject):
+ pass
diff --git a/eos/core/tests.py b/eos/core/tests.py
index 14a6f4a..9bbde91 100644
--- a/eos/core/tests.py
+++ b/eos/core/tests.py
@@ -19,14 +19,15 @@ from unittest import TestCase
from eos.core.objects import *
class PyTestCase(TestCase):
- def setUp(self):
+ @classmethod
+ def setUpClass(cls):
class Person(TopLevelObject):
name = StringField()
address = StringField(default=None)
def say_hi(self):
return 'Hello! My name is ' + self.name
- self.Person = Person
+ cls.Person = Person
def test_basic_py(self):
person1 = self.Person(name='John', address='Address 1')
@@ -36,3 +37,11 @@ class PyTestCase(TestCase):
self.assertEqual(person2.address, 'Address 2')
self.assertEqual(person1.say_hi(), 'Hello! My name is John')
self.assertEqual(person2.say_hi(), 'Hello! My name is James')
+
+ def test_serialise_py(self):
+ person1 = self.Person(name='John', address='Address 1')
+ expect1 = {'name': 'John', 'address': 'Address 1'}
+
+ self.assertEqual(person1.serialise(), expect1)
+ self.assertEqual(EosObject.serialise_and_wrap(person1, self.Person), expect1)
+ self.assertEqual(EosObject.serialise_and_wrap(person1), {'type': 'eos.core.tests.PyTestCase.setUpClass..Person', 'value': expect1})
diff --git a/test.sh b/test.sh
new file mode 100755
index 0000000..4574fed
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+ARGS=-vvv
+
+for test in eos.core.tests eos.base.tests; do
+ python -m unittest $test $ARGS || exit 1
+done