From 4d9ad7226a578bf0e1d8294329a7299194639955 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Fri, 15 Dec 2017 19:51:03 +1030 Subject: [PATCH] And more abstraction to fields and prepare for related fields --- eos/core/bigint/js.py | 4 +- eos/core/bigint/python.py | 4 +- eos/core/db/__init__.py | 6 +++ eos/core/db/mongodb.py | 10 ++++ eos/core/db/postgresql.py | 14 +++++ eos/core/hashing/__init__.py | 4 +- eos/core/objects/__init__.py | 102 +++++++++++++++++++++++++---------- eos/psr/bitstream.py | 4 +- eosweb/core/main.py | 7 +-- 9 files changed, 116 insertions(+), 39 deletions(-) diff --git a/eos/core/bigint/js.py b/eos/core/bigint/js.py index 7ca157d..02244fc 100644 --- a/eos/core/bigint/js.py +++ b/eos/core/bigint/js.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from eos.core.objects import EosObject +from eos.core.objects import * import random @@ -125,7 +125,7 @@ class BigInt(EosObject): def nbits(self): return self.impl.bitLength() - def serialise(self, for_hash=False, should_protect=False): + def serialise(self, options=SerialiseOptions.DEFAULT): return str(self) @classmethod diff --git a/eos/core/bigint/python.py b/eos/core/bigint/python.py index 268dca1..76ed567 100644 --- a/eos/core/bigint/python.py +++ b/eos/core/bigint/python.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from eos.core.objects import EosObject +from eos.core.objects import * import math @@ -46,7 +46,7 @@ class BigInt(EosObject): def nbits(self): return math.ceil(math.log2(self.impl)) if self.impl > 0 else 0 - def serialise(self, for_hash=False, should_protect=False): + def serialise(self, options=SerialiseOptions.DEFAULT): return str(self) @classmethod diff --git a/eos/core/db/__init__.py b/eos/core/db/__init__.py index 58272ba..05504db 100644 --- a/eos/core/db/__init__.py +++ b/eos/core/db/__init__.py @@ -27,6 +27,9 @@ class DBProvider: def get_all(self, collection): raise Exception('Not implemented') + def get_all_by_fields(self, collection, fields): + raise Exception('Not implemented') + def get_by_id(self, collection, _id): raise Exception('Not implemented') @@ -43,6 +46,9 @@ class DummyProvider(DBProvider): def get_all(self, collection): pass + def get_all_by_fields(self, collection, fields): + pass + def get_by_id(self, collection, _id): pass diff --git a/eos/core/db/mongodb.py b/eos/core/db/mongodb.py index 14bfed7..a52437f 100644 --- a/eos/core/db/mongodb.py +++ b/eos/core/db/mongodb.py @@ -26,6 +26,16 @@ class MongoDBProvider(eos.core.db.DBProvider): def get_all(self, collection): return self.db[collection].find() + def get_all_by_fields(self, collection, fields): + query = {} + if '_id' in fields: + query['_id'] = fields.pop('_id') + if 'type' in fields: + query['type'] = fields.pop('type') + for field in fields: + query['value.' + field] = fields.pop(field) + return self.db[collection].find(query) + def get_by_id(self, collection, _id): return self.db[collection].find_one(_id) diff --git a/eos/core/db/postgresql.py b/eos/core/db/postgresql.py index 1512c81..d5668e0 100644 --- a/eos/core/db/postgresql.py +++ b/eos/core/db/postgresql.py @@ -34,6 +34,20 @@ class PostgreSQLDBProvider(eos.core.db.DBProvider): self.cur.execute(SQL('SELECT data FROM {}').format(Identifier(table))) return [x[0] for x in self.cur.fetchall()] + def get_all_by_fields(self, table, fields): + # TODO: Make this much better + result = [] + for val in self.get_all(table): + if '_id' in fields and val['_id'] != fields.pop('_id'): + continue + if 'type' in fields and val['type'] != fields.pop('type'): + continue + for field in fields: + if val['value'][field] != fields[field]: + continue + result.append(val) + return result + def get_by_id(self, table, _id): self.create_table(table) self.cur.execute(SQL('SELECT data FROM {} WHERE _id = %s').format(Identifier(table)), (_id,)) diff --git a/eos/core/hashing/__init__.py b/eos/core/hashing/__init__.py index beb8e17..ea378ab 100644 --- a/eos/core/hashing/__init__.py +++ b/eos/core/hashing/__init__.py @@ -71,12 +71,12 @@ class SHA256: def update_obj(self, *values): for value in values: - self.update_text(EosObject.to_json(EosObject.serialise_and_wrap(value, None, True))) + self.update_text(EosObject.to_json(EosObject.serialise_and_wrap(value, None, SerialiseOptions(for_hash=True)))) return self def update_obj_raw(self, *values): for value in values: - self.update_text(EosObject.to_json(EosObject.serialise_and_wrap(value, None, False))) + self.update_text(EosObject.to_json(EosObject.serialise_and_wrap(value, None, SerialiseOptions(for_hash=False)))) return self def hash_as_b64(self): diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py index f7caf72..6fc6122 100644 --- a/eos/core/objects/__init__.py +++ b/eos/core/objects/__init__.py @@ -74,9 +74,28 @@ class Field: self.default = kwargs['default'] if 'default' in kwargs else kwargs['py_default'] if 'py_default' in kwargs else None self.is_protected = kwargs['is_protected'] if 'is_protected' in kwargs else False self.is_hashed = kwargs['is_hashed'] if 'is_hashed' in kwargs else not self.is_protected + + def object_get(self, obj): + return obj._field_values[self.real_name] + + def object_set(self, obj, value): + obj._field_values[self.real_name] = value + + if isinstance(value, EosObject): + value._instance = (obj, self.real_name) + if not value._inited: + value.post_init() + +class SerialiseOptions: + def __init__(self, for_hash=False, should_protect=False, combine_related=False): + self.for_hash = for_hash + self.should_protect = should_protect + self.combine_related = combine_related + +SerialiseOptions.DEFAULT = SerialiseOptions() class PrimitiveField(Field): - def serialise(self, value, for_hash=False, should_protect=False): + def serialise(self, value, options=SerialiseOptions.DEFAULT): return value def deserialise(self, value): @@ -93,8 +112,8 @@ class EmbeddedObjectField(Field): super().__init__(*args, **kwargs) self.object_type = object_type - def serialise(self, value, for_hash=False, should_protect=False): - return EosObject.serialise_and_wrap(value, self.object_type, for_hash, should_protect) + def serialise(self, value, options=SerialiseOptions.DEFAULT): + return EosObject.serialise_and_wrap(value, self.object_type, options) def deserialise(self, value): return EosObject.deserialise_and_unwrap(value, self.object_type) @@ -104,8 +123,8 @@ class ListField(Field): super().__init__(default=EosList, *args, **kwargs) self.element_field = element_field - def serialise(self, value, for_hash=False, should_protect=False): - return [self.element_field.serialise(x, for_hash, should_protect) for x in (value.impl if isinstance(value, EosList) else value)] + def serialise(self, value, options=SerialiseOptions.DEFAULT): + return [self.element_field.serialise(x, options) for x in (value.impl if isinstance(value, EosList) else value)] def deserialise(self, value): return EosList([self.element_field.deserialise(x) for x in value]) @@ -115,23 +134,55 @@ class EmbeddedObjectListField(Field): super().__init__(default=EosList, *args, **kwargs) self.object_type = object_type - def serialise(self, value, for_hash=False, should_protect=False): + def serialise(self, value, options=SerialiseOptions.DEFAULT): # TNYI: Doesn't know how to deal with iterators like EosList if value is None: return None - return [EosObject.serialise_and_wrap(x, self.object_type, for_hash, should_protect) for x in (value.impl if isinstance(value, EosList) else value)] + return [EosObject.serialise_and_wrap(x, self.object_type, options) for x in (value.impl if isinstance(value, EosList) else value)] def deserialise(self, value): if value is None: return None return EosList([EosObject.deserialise_and_unwrap(x, self.object_type) for x in value]) +class RelatedObjectListManager: + def __init__(self, field, obj): + self.field = field + self.obj = obj + + def get_all(self): + query = {self.field.related_field: getattr(self.obj, self.field.this_field)} + return self.field.object_type.get_all_by_fields(**query) + +class RelatedObjectListField(Field): + def __init__(self, object_type=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.object_type = object_type + self.this_field = args['this_field'] if 'this_field' in args else '_id' + self.related_field = args['related_field'] if 'related_field' in args else 'related_id' + + def object_get(self, obj): + return RelatedObjectListManager(self, obj) + + def object_set(self, obj, value): + raise Exception('Cannot directly set related field') + + def serialise(self, value, options=SerialiseOptions.DEFAULT): + if not options.combine_related: + return None + return EmbeddedObjectListField(object_type=self.object_type).serialise(value.get_all(), options) + + def deserialise(self, value): + if value is None: + return self.get_manager() + return EosList([EosObject.deserialise_and_unwrap(x, self.object_type) for x in value]) + if is_python: class UUIDField(Field): def __init__(self, *args, **kwargs): super().__init__(default=uuid.uuid4, *args, **kwargs) - def serialise(self, value, for_hash=False, should_protect=False): + def serialise(self, value, options=SerialiseOptions.DEFAULT): return str(value) def deserialise(self, value): @@ -145,7 +196,7 @@ class DateTimeField(Field): return '0' + str(number) return str(number) - def serialise(self, value, for_hash=False, should_protect=False): + def serialise(self, value, options=SerialiseOptions.DEFAULT): if value is None: return None @@ -215,12 +266,12 @@ class EosObject(metaclass=EosObjectType): return EosObject.objects[name] @staticmethod - def serialise_and_wrap(value, object_type=None, for_hash=False, should_protect=False): + def serialise_and_wrap(value, object_type=None, options=SerialiseOptions.DEFAULT): if object_type: if value: - return value.serialise(for_hash, should_protect) + return value.serialise(options) return None - return {'type': value._name, 'value': (value.serialise(for_hash, should_protect) if value else None)} + return {'type': value._name, 'value': (value.serialise(options) if value else None)} @staticmethod def deserialise_and_unwrap(value, object_type=None): @@ -327,14 +378,9 @@ class DocumentObjectType(EosObjectType): if is_python: def make_property(name, field): def field_getter(self): - return self._field_values[name] + return field.object_get(self) def field_setter(self, value): - self._field_values[name] = value - - if isinstance(value, EosObject): - value._instance = (self, name) - if not value._inited: - value.post_init() + field.object_set(self, value) return property(field_getter, field_setter) for attr, val in fields.items(): @@ -362,15 +408,11 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType): pass else: def make_property(name, field): + # TNYI: Transcrypt doesn't pass self def field_getter(): - return self._field_values[name] + return field.object_get(self) def field_setter(value): - self._field_values[name] = value - - if isinstance(value, EosObject): - value._instance = (self, name) - if not value._inited: - value.post_init() + field.object_set(self, value) return (field_getter, field_setter) prop = make_property(val.real_name, val) # TNYI: No support for property() @@ -393,8 +435,8 @@ class DocumentObject(EosObject, metaclass=DocumentObjectType): default = default() setattr(self, val.real_name, default) - def serialise(self, for_hash=False, should_protect=False): - return {val.real_name: val.serialise(getattr(self, val.real_name), for_hash, should_protect) for attr, val in self._fields.items() if ((val.is_hashed or not for_hash) and (not should_protect or not val.is_protected))} + def serialise(self, options=SerialiseOptions.DEFAULT): + return {val.real_name: val.serialise(getattr(self, val.real_name), options) for attr, val in self._fields.items() if ((val.is_hashed or not options.for_hash) and (not options.should_protect or not val.is_protected))} @classmethod def deserialise(cls, value): @@ -435,6 +477,10 @@ class TopLevelObject(DocumentObject, metaclass=TopLevelObjectType): def get_all(cls): return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all(cls._db_name)] + @classmethod + def get_all_by_fields(cls, **fields): + return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all_by_fields(cls._db_name, fields)] + @classmethod def get_by_id(cls, _id): if not isinstance(_id, str): diff --git a/eos/psr/bitstream.py b/eos/psr/bitstream.py index daff402..9c0f5de 100644 --- a/eos/psr/bitstream.py +++ b/eos/psr/bitstream.py @@ -153,7 +153,7 @@ class BitStream(EosObject): bs.seek(0) return bs - def serialise(self): + def serialise(self, options=SerialiseOptions.DEFAULT): return self.impl @classmethod @@ -173,7 +173,7 @@ class InfiniteHashBitStream(BitStream): # 11000110110 # ^---- if nbits is None: - nbits = self.remaining + raise Exception('Cannot read indefinite amount from InfiniteHashBitStream') while nbits > self.remaining: self.ctr += 1 self.sha.update_text(str(self.ctr)) diff --git a/eosweb/core/main.py b/eosweb/core/main.py index 36bbd3d..e3bf6b9 100644 --- a/eosweb/core/main.py +++ b/eosweb/core/main.py @@ -214,7 +214,8 @@ def election_admin(func): @app.route('/election//') @using_election def election_api_json(election): - return flask.Response(EosObject.to_json(EosObject.serialise_and_wrap(election, should_protect=True, for_hash=('full' not in flask.request.args))), mimetype='application/json') + is_full = 'full' in flask.request.args + return flask.Response(EosObject.to_json(EosObject.serialise_and_wrap(election, None, SerialiseOptions(should_protect=True, for_hash=(not is_full), combine_related=True))), mimetype='application/json') @app.route('/election//view') @using_election @@ -326,8 +327,8 @@ def election_api_cast_vote(election): election.save() return flask.Response(json.dumps({ - 'voter': EosObject.serialise_and_wrap(voter, should_protect=True), - 'vote': EosObject.serialise_and_wrap(vote, should_protect=True) + 'voter': EosObject.serialise_and_wrap(voter, None, SerialiseOptions(should_protect=True)), + 'vote': EosObject.serialise_and_wrap(vote, None, SerialiseOptions(should_protect=True)) }), mimetype='application/json') @app.route('/election//export/question//')