Store votes in a separate collection for better concurrency support

This commit is contained in:
RunasSudo 2017-12-15 20:21:57 +10:30 committed by Yingtong Li
parent 4d9ad7226a
commit 9255899a01
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
9 changed files with 47 additions and 32 deletions

View File

@ -46,8 +46,11 @@ class Ballot(EmbeddedObject):
return Ballot(encrypted_answers=encrypted_answers_deaudit, election_id=self.election_id, election_hash=self.election_hash) return Ballot(encrypted_answers=encrypted_answers_deaudit, election_id=self.election_id, election_hash=self.election_hash)
class Vote(EmbeddedObject): class Vote(TopLevelObject):
_ver = StringField(default='0.5') _ver = StringField(default='0.6')
_id = UUIDField()
voter_id = UUIDField()
ballot = EmbeddedObjectField() ballot = EmbeddedObjectField()
cast_at = DateTimeField() cast_at = DateTimeField()
@ -57,8 +60,10 @@ class Vote(EmbeddedObject):
cast_fingerprint = BlobField(is_protected=True) cast_fingerprint = BlobField(is_protected=True)
class Voter(EmbeddedObject): class Voter(EmbeddedObject):
_ver = StringField(default='0.6')
_id = UUIDField() _id = UUIDField()
votes = EmbeddedObjectListField() votes = RelatedObjectListField(related_type=Vote, object_type=None, this_field='_id', related_field='voter_id')
class User(EmbeddedObject): class User(EmbeddedObject):
admins = [] admins = []
@ -229,7 +234,7 @@ class Election(TopLevelObject):
election_hash = SHA256().update_obj(self).hash_as_b64() election_hash = SHA256().update_obj(self).hash_as_b64()
for voter in self.voters: for voter in self.voters:
for vote in voter.votes: for vote in voter.votes.get_all():
if vote.ballot.election_id != self._id: if vote.ballot.election_id != self._id:
raise Exception('Invalid election ID on ballot') raise Exception('Invalid election ID on ballot')
if vote.ballot.election_hash != election_hash: if vote.ballot.election_hash != election_hash:

View File

@ -95,10 +95,10 @@ class ElectionTestCase(EosTestCase):
answer = ApprovalAnswer(choices=VOTES[i][j]) answer = ApprovalAnswer(choices=VOTES[i][j])
encrypted_answer = NullEncryptedAnswer(answer=answer) encrypted_answer = NullEncryptedAnswer(answer=answer)
ballot.encrypted_answers.append(encrypted_answer) ballot.encrypted_answers.append(encrypted_answer)
vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) vote = Vote(voter_id=election.voters[i]._id, ballot=ballot, cast_at=DateTimeField.now())
election.voters[i].votes.append(vote) vote.save()
election.save() #election.save()
# Close voting # Close voting
self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.base.workflow.TaskDecryptVotes') self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.base.workflow.TaskDecryptVotes')

View File

@ -167,8 +167,8 @@ class TaskDecryptVotes(WorkflowTask):
election.results.append(EosObject.lookup('eos.base.election.RawResult')()) election.results.append(EosObject.lookup('eos.base.election.RawResult')())
for voter in election.voters: for voter in election.voters:
if len(voter.votes) > 0: if len(voter.votes.get_all()) > 0:
vote = voter.votes[-1] vote = voter.votes.get_all()[-1]
ballot = vote.ballot ballot = vote.ballot
for q_num in range(len(ballot.encrypted_answers)): for q_num in range(len(ballot.encrypted_answers)):
plaintexts, answer = ballot.encrypted_answers[q_num].decrypt() plaintexts, answer = ballot.encrypted_answers[q_num].decrypt()

View File

@ -33,7 +33,7 @@ class MongoDBProvider(eos.core.db.DBProvider):
if 'type' in fields: if 'type' in fields:
query['type'] = fields.pop('type') query['type'] = fields.pop('type')
for field in fields: for field in fields:
query['value.' + field] = fields.pop(field) query['value.' + field] = fields[field]
return self.db[collection].find(query) return self.db[collection].find(query)
def get_by_id(self, collection, _id): def get_by_id(self, collection, _id):

View File

@ -85,6 +85,9 @@ class Field:
value._instance = (obj, self.real_name) value._instance = (obj, self.real_name)
if not value._inited: if not value._inited:
value.post_init() value.post_init()
def object_init(self, obj, value):
self.object_set(obj, value)
class SerialiseOptions: class SerialiseOptions:
def __init__(self, for_hash=False, should_protect=False, combine_related=False): def __init__(self, for_hash=False, should_protect=False, combine_related=False):
@ -152,14 +155,15 @@ class RelatedObjectListManager:
def get_all(self): def get_all(self):
query = {self.field.related_field: getattr(self.obj, self.field.this_field)} query = {self.field.related_field: getattr(self.obj, self.field.this_field)}
return self.field.object_type.get_all_by_fields(**query) return self.field.related_type.get_all_by_fields(**query)
class RelatedObjectListField(Field): class RelatedObjectListField(Field):
def __init__(self, object_type=None, *args, **kwargs): def __init__(self, object_type=None, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.object_type = object_type self.related_type = kwargs['related_type']
self.this_field = args['this_field'] if 'this_field' in args else '_id' self.object_type = kwargs['object_type'] if 'object_type' in kwargs else None
self.related_field = args['related_field'] if 'related_field' in args else 'related_id' self.this_field = kwargs['this_field'] if 'this_field' in kwargs else '_id'
self.related_field = kwargs['related_field']
def object_get(self, obj): def object_get(self, obj):
return RelatedObjectListManager(self, obj) return RelatedObjectListManager(self, obj)
@ -167,6 +171,9 @@ class RelatedObjectListField(Field):
def object_set(self, obj, value): def object_set(self, obj, value):
raise Exception('Cannot directly set related field') raise Exception('Cannot directly set related field')
def object_init(self, obj, value):
pass
def serialise(self, value, options=SerialiseOptions.DEFAULT): def serialise(self, value, options=SerialiseOptions.DEFAULT):
if not options.combine_related: if not options.combine_related:
return None return None
@ -174,7 +181,7 @@ class RelatedObjectListField(Field):
def deserialise(self, value): def deserialise(self, value):
if value is None: if value is None:
return self.get_manager() return None
return EosList([EosObject.deserialise_and_unwrap(x, self.object_type) for x in value]) return EosList([EosObject.deserialise_and_unwrap(x, self.object_type) for x in value])
if is_python: if is_python:
@ -270,8 +277,9 @@ class EosObject(metaclass=EosObjectType):
if object_type: if object_type:
if value: if value:
return value.serialise(options) return value.serialise(options)
return None if value:
return {'type': value._name, 'value': (value.serialise(options) if value else None)} return {'type': value._name, 'value': (value.serialise(options) if value else None)}
return None
@staticmethod @staticmethod
def deserialise_and_unwrap(value, object_type=None): def deserialise_and_unwrap(value, object_type=None):
@ -428,12 +436,12 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType):
}) })
if val.internal_name in kwargs: if val.internal_name in kwargs:
setattr(self, val.real_name, kwargs[val.internal_name]) val.object_init(self, kwargs[val.internal_name])
else: else:
default = val.default default = val.default
if default is not None and callable(default): if default is not None and callable(default):
default = default() default = default()
setattr(self, val.real_name, default) val.object_init(self, default)
def serialise(self, options=SerialiseOptions.DEFAULT): def serialise(self, options=SerialiseOptions.DEFAULT):
return {val.real_name: val.serialise(getattr(self, val.real_name), options) for attr, val in self._fields.items() if ((val.is_hashed or not options.for_hash) and (not options.should_protect or not val.is_protected))} return {val.real_name: val.serialise(getattr(self, val.real_name), options) for attr, val in self._fields.items() if ((val.is_hashed or not options.for_hash) and (not options.should_protect or not val.is_protected))}
@ -479,6 +487,9 @@ class TopLevelObject(DocumentObject, metaclass=TopLevelObjectType):
@classmethod @classmethod
def get_all_by_fields(cls, **fields): def get_all_by_fields(cls, **fields):
for field in fields:
if not isinstance(fields[field], str):
fields[field] = str(fields[field])
return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all_by_fields(cls._db_name, fields)] return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all_by_fields(cls._db_name, fields)]
@classmethod @classmethod

View File

@ -96,8 +96,8 @@ class MixingTrustee(Trustee):
# Use the raw ballots from voters # Use the raw ballots from voters
orig_answers = [] orig_answers = []
for voter in self.recurse_parents(Election).voters: for voter in self.recurse_parents(Election).voters:
if len(voter.votes) > 0: if len(voter.votes.get_all()) > 0:
vote = voter.votes[-1] vote = voter.votes.get_all()[-1]
ballot = vote.ballot ballot = vote.ballot
orig_answers.append(ballot.encrypted_answers[question_num]) orig_answers.append(ballot.encrypted_answers[question_num])
return orig_answers return orig_answers
@ -195,8 +195,8 @@ class InternalMixingTrustee(MixingTrustee):
else: else:
orig_answers = [] orig_answers = []
for voter in election.voters: for voter in election.voters:
if len(voter.votes) > 0: if len(voter.votes.get_all()) > 0:
ballot = voter.votes[-1].ballot ballot = voter.votes.get_all()[-1].ballot
orig_answers.append(ballot.encrypted_answers[question]) orig_answers.append(ballot.encrypted_answers[question])
shuffled_answers, commitments = self.mixnets[question].shuffle(orig_answers) shuffled_answers, commitments = self.mixnets[question].shuffle(orig_answers)
self.mixed_questions.append(EosList(shuffled_answers)) self.mixed_questions.append(EosList(shuffled_answers))

View File

@ -278,10 +278,10 @@ class ElectionTestCase(EosTestCase):
answer = ApprovalAnswer(choices=VOTES[i][j]) answer = ApprovalAnswer(choices=VOTES[i][j])
encrypted_answer = BlockEncryptedAnswer.encrypt(election.sk.public_key, answer) encrypted_answer = BlockEncryptedAnswer.encrypt(election.sk.public_key, answer)
ballot.encrypted_answers.append(encrypted_answer) ballot.encrypted_answers.append(encrypted_answer)
vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) vote = Vote(voter_id=election.voters[i]._id, ballot=ballot, cast_at=DateTimeField.now())
election.voters[i].votes.append(vote) vote.save()
election.save() #election.save()
# Close voting # Close voting
self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.psr.workflow.TaskMixVotes') self.do_task_assert(election, 'eos.base.workflow.TaskCloseVoting', 'eos.psr.workflow.TaskMixVotes')

View File

@ -311,7 +311,7 @@ def election_api_cast_vote(election):
# Cast the vote # Cast the vote
ballot = EosObject.deserialise_and_unwrap(data['ballot']) ballot = EosObject.deserialise_and_unwrap(data['ballot'])
vote = Vote(ballot=ballot, cast_at=DateTimeField.now()) vote = Vote(voter_id=voter._id, ballot=ballot, cast_at=DateTimeField.now())
# Store data # Store data
if app.config['CAST_FINGERPRINT']: if app.config['CAST_FINGERPRINT']:
@ -322,9 +322,7 @@ def election_api_cast_vote(election):
else: else:
vote.cast_ip = flask.request.remote_addr vote.cast_ip = flask.request.remote_addr
voter.votes.append(vote) vote.save()
election.save()
return flask.Response(json.dumps({ return flask.Response(json.dumps({
'voter': EosObject.serialise_and_wrap(voter, None, SerialiseOptions(should_protect=True)), 'voter': EosObject.serialise_and_wrap(voter, None, SerialiseOptions(should_protect=True)),

View File

@ -30,8 +30,9 @@
{% for voter in election.voters %} {% for voter in election.voters %}
<tr> <tr>
<td>{{ voter.name }}</td> <td>{{ voter.name }}</td>
{% if voter.votes|length > 0 %} {% set votes = voter.votes.get_all() %}
<td class="hash">{{ SHA256().update_obj(voter.votes[-1].ballot).hash_as_b64() }}</td> {% if votes|length > 0 %}
<td class="hash">{{ SHA256().update_obj(votes[-1].ballot).hash_as_b64() }}</td>
{% else %} {% else %}
<td class="hash"></td> <td class="hash"></td>
{% endif %} {% endif %}