diff --git a/eos/base/election.py b/eos/base/election.py index fa520bd..cea1e09 100644 --- a/eos/base/election.py +++ b/eos/base/election.py @@ -25,11 +25,11 @@ class PlaintextBallotQuestion(BallotQuestion): class Ballot(EmbeddedObject): _id = UUIDField() - questions = EmbeddedObjectListField(BallotQuestion) + questions = EmbeddedObjectListField() class Voter(EmbeddedObject): _id = UUIDField() - ballots = EmbeddedObjectListField(Ballot) + ballots = EmbeddedObjectListField() class Question(EmbeddedObject): prompt = StringField() @@ -39,7 +39,7 @@ class ApprovalQuestion(Question): class Election(TopLevelObject): _id = UUIDField() - workflow = EmbeddedObjectField(Workflow) + workflow = EmbeddedObjectField(Workflow) # Once saved, we don't care what kind of workflow it is name = StringField() - voters = EmbeddedObjectListField(Voter, hashed=False) - questions = EmbeddedObjectListField(Question) + voters = EmbeddedObjectListField(hashed=False) + questions = EmbeddedObjectListField() diff --git a/eos/base/tests.py b/eos/base/tests.py index 6d443e6..c6ca686 100644 --- a/eos/base/tests.py +++ b/eos/base/tests.py @@ -50,15 +50,10 @@ class ElectionTestCase(TestCase): question.choices.append('Andrew Citizen') election.questions.append(question) - #election.save() + election.save() # Check that it saved - #self.assertEqual(Election.objects.get(0)._id, election._id) # TODO: Compare JSON - - # Retrieve from scratch, too - #self.db.close() - #self.db = mongoengine.connect('test') - #self.assertEqual(Election.objects.get(0)._id, election._id) + self.assertEqual(db[Election._name].find_one(), election.serialise()) # Freeze election self.assertEqual(election.workflow.get_task('eos.base.workflow.TaskConfigureElection').status, WorkflowTask.Status.READY) diff --git a/eos/base/workflow.py b/eos/base/workflow.py index d89cea2..fd34e19 100644 --- a/eos/base/workflow.py +++ b/eos/base/workflow.py @@ -74,7 +74,7 @@ class WorkflowTask(EmbeddedObject): self.fire_event('exit') class Workflow(EmbeddedObject): - tasks = EmbeddedObjectListField(WorkflowTask) + tasks = EmbeddedObjectListField() meta = { 'abstract': True } diff --git a/eos/core/objects/python.py b/eos/core/objects/python.py index 49cf97c..34909d6 100644 --- a/eos/core/objects/python.py +++ b/eos/core/objects/python.py @@ -15,6 +15,7 @@ # along with this program. If not, see . import pymongo +from bson.binary import UUIDLegacy import uuid @@ -50,7 +51,7 @@ class EmbeddedObjectField(Field): self.object_type = object_type def serialise(self, value): - return value.serialise_and_wrap(self.object_type) + return EosObject.serialise_and_wrap(value, self.object_type) def deserialise(self, value): return EosObject.deserialise_and_unwrap(value, self.object_type) @@ -66,28 +67,59 @@ class ListField(Field): def deserialise(self, value): return [self.element_field.deserialise(x) for x in value] -EmbeddedObjectListField = ListField +class EmbeddedObjectListField(Field): + def __init__(self, object_type=None, *args, **kwargs): + super().__init__(default=[], *args, **kwargs) + self.object_type = object_type + + def serialise(self, value): + return [EosObject.serialise_and_wrap(x, self.object_type) for x in value] + + def deserialise(self, value): + return [EosObject.deserialise_and_unwrap(x, self.object_type) for x in value] class UUIDField(Field): def __init__(self, *args, **kwargs): super().__init__(default=uuid.uuid4, *args, **kwargs) def serialise(self, value): - return str(uuid.uuid4) + return str(value) def unserialise(self, value): - return uuid.uuid4(value) + return uuid.UUID(value) # Objects # ======= class EosObjectType(type): def __new__(meta, name, bases, attrs): - #meta, name, bases, attrs = meta.before_new(meta, name, bases, attrs) cls = type.__new__(meta, name, bases, attrs) + cls._name = cls.__module__ + '.' + cls.__qualname__ + if name != 'EosObject': + EosObject.objects[cls._name] = cls + return cls + +class EosObject(metaclass=EosObjectType): + objects = {} + + @staticmethod + def serialise_and_wrap(value, object_type=None): + if object_type: + return value.serialise() + return {'type': value._name, 'value': value.serialise()} + + @staticmethod + def deserialise_and_unwrap(value, object_type=None): + if object_type: + return object_type.deserialise(value) + return EosObject.objects[value['type']].deserialise(value['value']) + +class DocumentObjectType(EosObjectType): + def __new__(meta, name, bases, attrs): + cls = EosObjectType.__new__(meta, name, bases, attrs) # Process fields - fields = cls._fields if hasattr(cls, '_fields') else {} + fields = cls._fields.copy() if hasattr(cls, '_fields') else {} # remember to .copy() XD for attr in list(dir(cls)): val = getattr(cls, attr) if isinstance(val, Field): @@ -95,14 +127,29 @@ class EosObjectType(type): delattr(cls, attr) cls._fields = fields - cls._name = cls.__module__ + '.' + cls.__qualname__ - return cls -class EosObject(metaclass=EosObjectType): +class DocumentObject(metaclass=DocumentObjectType): def __init__(self, *args, **kwargs): for attr, val in self._fields.items(): - setattr(self, attr, kwargs.get(attr, val.default)) + if attr in kwargs: + setattr(self, attr, kwargs[attr]) + else: + default = val.default + if callable(default): + default = default() + setattr(self, attr, default) + + def serialise(self): + return {attr: val.serialise(getattr(self, attr)) for attr, val in self._fields.items()} + + @classmethod + def deserialise(cls, value): + return cls(**value) # wew -TopLevelObject = EosObject -EmbeddedObject = EosObject +class TopLevelObject(DocumentObject): + def save(self): + res = db[self._name].replace_one({'_id': self.serialise()['_id']}, self.serialise(), upsert=True) + +class EmbeddedObject(DocumentObject): + pass diff --git a/eos/core/tests.py b/eos/core/tests.py index 14a6f4a..9bbde91 100644 --- a/eos/core/tests.py +++ b/eos/core/tests.py @@ -19,14 +19,15 @@ from unittest import TestCase from eos.core.objects import * class PyTestCase(TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): class Person(TopLevelObject): name = StringField() address = StringField(default=None) def say_hi(self): return 'Hello! My name is ' + self.name - self.Person = Person + cls.Person = Person def test_basic_py(self): person1 = self.Person(name='John', address='Address 1') @@ -36,3 +37,11 @@ class PyTestCase(TestCase): self.assertEqual(person2.address, 'Address 2') self.assertEqual(person1.say_hi(), 'Hello! My name is John') self.assertEqual(person2.say_hi(), 'Hello! My name is James') + + def test_serialise_py(self): + person1 = self.Person(name='John', address='Address 1') + expect1 = {'name': 'John', 'address': 'Address 1'} + + self.assertEqual(person1.serialise(), expect1) + self.assertEqual(EosObject.serialise_and_wrap(person1, self.Person), expect1) + self.assertEqual(EosObject.serialise_and_wrap(person1), {'type': 'eos.core.tests.PyTestCase.setUpClass..Person', 'value': expect1}) diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..4574fed --- /dev/null +++ b/test.sh @@ -0,0 +1,6 @@ +#!/bin/bash +ARGS=-vvv + +for test in eos.core.tests eos.base.tests; do + python -m unittest $test $ARGS || exit 1 +done