From 9255899a0146fcb697f1e2edd48e721f05af5cf3 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Fri, 15 Dec 2017 20:21:57 +1030 Subject: [PATCH] Store votes in a separate collection for better concurrency support --- eos/base/election.py | 13 ++++++--- eos/base/tests.py | 6 ++-- eos/base/workflow.py | 4 +-- eos/core/db/mongodb.py | 2 +- eos/core/objects/__init__.py | 29 +++++++++++++------ eos/psr/election.py | 8 ++--- eos/psr/tests.py | 6 ++-- eosweb/core/main.py | 6 ++-- .../core/templates/election/view/ballots.html | 5 ++-- 9 files changed, 47 insertions(+), 32 deletions(-) diff --git a/eos/base/election.py b/eos/base/election.py index cb56e4c..ad730d5 100644 --- a/eos/base/election.py +++ b/eos/base/election.py @@ -46,8 +46,11 @@ class Ballot(EmbeddedObject): return Ballot(encrypted_answers=encrypted_answers_deaudit, election_id=self.election_id, election_hash=self.election_hash) -class Vote(EmbeddedObject): - _ver = StringField(default='0.5') +class Vote(TopLevelObject): + _ver = StringField(default='0.6') + + _id = UUIDField() + voter_id = UUIDField() ballot = EmbeddedObjectField() cast_at = DateTimeField() @@ -57,8 +60,10 @@ class Vote(EmbeddedObject): cast_fingerprint = BlobField(is_protected=True) class Voter(EmbeddedObject): + _ver = StringField(default='0.6') + _id = UUIDField() - votes = EmbeddedObjectListField() + votes = RelatedObjectListField(related_type=Vote, object_type=None, this_field='_id', related_field='voter_id') class User(EmbeddedObject): admins = [] @@ -229,7 +234,7 @@ class Election(TopLevelObject): election_hash = SHA256().update_obj(self).hash_as_b64() for voter in self.voters: - for vote in voter.votes: + for vote in voter.votes.get_all(): if vote.ballot.election_id != self._id: raise Exception('Invalid election ID on ballot') if vote.ballot.election_hash != election_hash: diff --git a/eos/base/tests.py b/eos/base/tests.py index 78abbcb..6be2bb4 100644 --- a/eos/base/tests.py +++ b/eos/base/tests.py @@ -95,10 +95,10 @@ class ElectionTestCase(EosTestCase): answer = ApprovalAnswer(choices=VOTES[i][j]) encrypted_answer = NullEncryptedAnswer(answer=answer) ballot.encrypted_answers.append(encrypted_answer) - vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) - election.voters[i].votes.append(vote) + vote = Vote(voter_id=election.voters[i]._id, ballot=ballot, cast_at=DateTimeField.now()) + vote.save() - election.save() + #election.save() # Close voting self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.base.workflow.TaskDecryptVotes') diff --git a/eos/base/workflow.py b/eos/base/workflow.py index d934bbf..7ddd10e 100644 --- a/eos/base/workflow.py +++ b/eos/base/workflow.py @@ -167,8 +167,8 @@ class TaskDecryptVotes(WorkflowTask): election.results.append(EosObject.lookup('eos.base.election.RawResult')()) for voter in election.voters: - if len(voter.votes) > 0: - vote = voter.votes[-1] + if len(voter.votes.get_all()) > 0: + vote = voter.votes.get_all()[-1] ballot = vote.ballot for q_num in range(len(ballot.encrypted_answers)): plaintexts, answer = ballot.encrypted_answers[q_num].decrypt() diff --git a/eos/core/db/mongodb.py b/eos/core/db/mongodb.py index a52437f..73702ee 100644 --- a/eos/core/db/mongodb.py +++ b/eos/core/db/mongodb.py @@ -33,7 +33,7 @@ class MongoDBProvider(eos.core.db.DBProvider): if 'type' in fields: query['type'] = fields.pop('type') for field in fields: - query['value.' + field] = fields.pop(field) + query['value.' + field] = fields[field] return self.db[collection].find(query) def get_by_id(self, collection, _id): diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py index 6fc6122..885eaed 100644 --- a/eos/core/objects/__init__.py +++ b/eos/core/objects/__init__.py @@ -85,6 +85,9 @@ class Field: value._instance = (obj, self.real_name) if not value._inited: value.post_init() + + def object_init(self, obj, value): + self.object_set(obj, value) class SerialiseOptions: def __init__(self, for_hash=False, should_protect=False, combine_related=False): @@ -152,14 +155,15 @@ class RelatedObjectListManager: def get_all(self): query = {self.field.related_field: getattr(self.obj, self.field.this_field)} - return self.field.object_type.get_all_by_fields(**query) + return self.field.related_type.get_all_by_fields(**query) class RelatedObjectListField(Field): def __init__(self, object_type=None, *args, **kwargs): super().__init__(*args, **kwargs) - self.object_type = object_type - self.this_field = args['this_field'] if 'this_field' in args else '_id' - self.related_field = args['related_field'] if 'related_field' in args else 'related_id' + self.related_type = kwargs['related_type'] + self.object_type = kwargs['object_type'] if 'object_type' in kwargs else None + self.this_field = kwargs['this_field'] if 'this_field' in kwargs else '_id' + self.related_field = kwargs['related_field'] def object_get(self, obj): return RelatedObjectListManager(self, obj) @@ -167,6 +171,9 @@ class RelatedObjectListField(Field): def object_set(self, obj, value): raise Exception('Cannot directly set related field') + def object_init(self, obj, value): + pass + def serialise(self, value, options=SerialiseOptions.DEFAULT): if not options.combine_related: return None @@ -174,7 +181,7 @@ class RelatedObjectListField(Field): def deserialise(self, value): if value is None: - return self.get_manager() + return None return EosList([EosObject.deserialise_and_unwrap(x, self.object_type) for x in value]) if is_python: @@ -270,8 +277,9 @@ class EosObject(metaclass=EosObjectType): if object_type: if value: return value.serialise(options) - return None - return {'type': value._name, 'value': (value.serialise(options) if value else None)} + if value: + return {'type': value._name, 'value': (value.serialise(options) if value else None)} + return None @staticmethod def deserialise_and_unwrap(value, object_type=None): @@ -428,12 +436,12 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType): }) if val.internal_name in kwargs: - setattr(self, val.real_name, kwargs[val.internal_name]) + val.object_init(self, kwargs[val.internal_name]) else: default = val.default if default is not None and callable(default): default = default() - setattr(self, val.real_name, default) + val.object_init(self, default) def serialise(self, options=SerialiseOptions.DEFAULT): return {val.real_name: val.serialise(getattr(self, val.real_name), options) for attr, val in self._fields.items() if ((val.is_hashed or not options.for_hash) and (not options.should_protect or not val.is_protected))} @@ -479,6 +487,9 @@ class TopLevelObject(DocumentObject, metaclass=TopLevelObjectType): @classmethod def get_all_by_fields(cls, **fields): + for field in fields: + if not isinstance(fields[field], str): + fields[field] = str(fields[field]) return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all_by_fields(cls._db_name, fields)] @classmethod diff --git a/eos/psr/election.py b/eos/psr/election.py index cb5e854..f54ef3c 100644 --- a/eos/psr/election.py +++ b/eos/psr/election.py @@ -96,8 +96,8 @@ class MixingTrustee(Trustee): # Use the raw ballots from voters orig_answers = [] for voter in self.recurse_parents(Election).voters: - if len(voter.votes) > 0: - vote = voter.votes[-1] + if len(voter.votes.get_all()) > 0: + vote = voter.votes.get_all()[-1] ballot = vote.ballot orig_answers.append(ballot.encrypted_answers[question_num]) return orig_answers @@ -195,8 +195,8 @@ class InternalMixingTrustee(MixingTrustee): else: orig_answers = [] for voter in election.voters: - if len(voter.votes) > 0: - ballot = voter.votes[-1].ballot + if len(voter.votes.get_all()) > 0: + ballot = voter.votes.get_all()[-1].ballot orig_answers.append(ballot.encrypted_answers[question]) shuffled_answers, commitments = self.mixnets[question].shuffle(orig_answers) self.mixed_questions.append(EosList(shuffled_answers)) diff --git a/eos/psr/tests.py b/eos/psr/tests.py index 8104f71..4892511 100644 --- a/eos/psr/tests.py +++ b/eos/psr/tests.py @@ -278,10 +278,10 @@ class ElectionTestCase(EosTestCase): answer = ApprovalAnswer(choices=VOTES[i][j]) encrypted_answer = BlockEncryptedAnswer.encrypt(election.sk.public_key, answer) ballot.encrypted_answers.append(encrypted_answer) - vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) - election.voters[i].votes.append(vote) + vote = Vote(voter_id=election.voters[i]._id, ballot=ballot, cast_at=DateTimeField.now()) + vote.save() - election.save() + #election.save() # Close voting self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.psr.workflow.TaskMixVotes') diff --git a/eosweb/core/main.py b/eosweb/core/main.py index e3bf6b9..6ac9a62 100644 --- a/eosweb/core/main.py +++ b/eosweb/core/main.py @@ -311,7 +311,7 @@ def election_api_cast_vote(election): # Cast the vote ballot = EosObject.deserialise_and_unwrap(data['ballot']) - vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) + vote = Vote(voter_id=voter._id, ballot=ballot, cast_at=DateTimeField.now()) # Store data if app.config['CAST_FINGERPRINT']: @@ -322,9 +322,7 @@ def election_api_cast_vote(election): else: vote.cast_ip = flask.request.remote_addr - voter.votes.append(vote) - - election.save() + vote.save() return flask.Response(json.dumps({ 'voter': EosObject.serialise_and_wrap(voter, None, SerialiseOptions(should_protect=True)), diff --git a/eosweb/core/templates/election/view/ballots.html b/eosweb/core/templates/election/view/ballots.html index 7fb4438..a331047 100644 --- a/eosweb/core/templates/election/view/ballots.html +++ b/eosweb/core/templates/election/view/ballots.html @@ -30,8 +30,9 @@ {% for voter in election.voters %} {{ voter.name }} - {% if voter.votes|length > 0 %} - {{ SHA256().update_obj(voter.votes[-1].ballot).hash_as_b64() }} + {% set votes = voter.votes.get_all() %} + {% if votes|length > 0 %} + {{ SHA256().update_obj(votes[-1].ballot).hash_as_b64() }} {% else %} {% endif %}