diff --git a/eos/base/tests.py b/eos/base/tests.py index 45e2dff..be6d865 100644 --- a/eos/base/tests.py +++ b/eos/base/tests.py @@ -23,8 +23,7 @@ from eos.core.objects import * class ElectionTestCase(EosTestCase): @classmethod def setUpClass(cls): - db_connect('test') - dbinfo.provider.reset_db() + cls.db_connect_and_reset() def do_task_assert(self, election, task, next_task): self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY) diff --git a/eos/core/db/__init__.py b/eos/core/db/__init__.py index 61c5249..58272ba 100644 --- a/eos/core/db/__init__.py +++ b/eos/core/db/__init__.py @@ -20,3 +20,36 @@ class DBProvider: def __init__(self, db_name, db_uri): self.db_name = db_name self.db_uri = db_uri + + def connect(self): + raise Exception('Not implemented') + + def get_all(self, collection): + raise Exception('Not implemented') + + def get_by_id(self, collection, _id): + raise Exception('Not implemented') + + def update_by_id(self, collection, _id, value): + raise Exception('Not implemented') + + def reset_db(self): + raise Exception('Not implemented') + +class DummyProvider(DBProvider): + def connect(self): + pass + + def get_all(self, collection): + pass + + def get_by_id(self, collection, _id): + pass + + def update_by_id(self, collection, _id, value): + pass + + def reset_db(self): + pass + +db_providers['dummy'] = DummyProvider diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py index 053c0b9..8c7ed9a 100644 --- a/eos/core/objects/__init__.py +++ b/eos/core/objects/__init__.py @@ -25,9 +25,10 @@ except: # Libraries # ========= +import eos.core.db + if is_python: __pragma__('skip') - import eos.core.db import eos.core.db.mongodb import eos.core.db.postgresql @@ -54,7 +55,7 @@ else: class DBInfo: def __init__(self): - self.provider = None + self.provider = eos.core.db.DummyProvider(None, None) dbinfo = DBInfo() diff --git a/eos/core/tasks/__init__.py b/eos/core/tasks/__init__.py new file mode 100644 index 0000000..a8c5522 --- /dev/null +++ b/eos/core/tasks/__init__.py @@ -0,0 +1,63 @@ +# Eos - Verifiable elections +# Copyright © 2017 RunasSudo (Yingtong Li) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from eos.core.objects import * + +class Task(TopLevelObject): + class Status: + UNKNOWN = 0 + + READY = 20 + PROCESSING = 30 + COMPLETE = 50 + + FAILED = -10 + TIMEOUT = -20 + + _id = UUIDField() + status = IntField(default=0) + run_strategy = EmbeddedObjectField() + messages = ListField(StringField()) + + def run(self): + self.run_strategy.run(self) + + def _run(self): + pass + +class RunStrategy(DocumentObject): + def run(self, task): + raise Exception('Not implemented') + +class DirectRunStrategy(RunStrategy): + def run(self, task): + task.status = Task.Status.PROCESSING + task.save() + + try: + task._run() + task.status = Task.Status.COMPLETE + task.save() + except Exception as e: + task.status = Task.Status.FAILED + if is_python: + #__pragma__('skip') + import traceback + #__pragma__('noskip') + task.messages.append(traceback.format_exc()) + else: + task.messages.append(repr(e)) + task.save() diff --git a/eos/core/tasks/direct.py b/eos/core/tasks/direct.py new file mode 100644 index 0000000..f5138fa --- /dev/null +++ b/eos/core/tasks/direct.py @@ -0,0 +1,19 @@ +# Eos - Verifiable elections +# Copyright © 2017 RunasSudo (Yingtong Li) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from eos.core.tasks import * + + diff --git a/eos/core/tests.py b/eos/core/tests.py index 93b623d..a8ed751 100644 --- a/eos/core/tests.py +++ b/eos/core/tests.py @@ -17,6 +17,7 @@ from eos.core.bigint import * from eos.core.objects import * from eos.core.hashing import * +from eos.core.tasks import * # Common library things # =================== @@ -26,6 +27,11 @@ class EosTestCase: def setUpClass(cls): pass + @classmethod + def db_connect_and_reset(cls): + db_connect('test') + dbinfo.provider.reset_db() + def assertTrue(self, a): if is_python: self.impl.assertTrue(a) @@ -120,3 +126,37 @@ class BigIntTestCase(EosTestCase): self.assertEqual(pow(bigint1, bigint2), 5**10) self.assertEqual(pow(bigint1, bigint2, bigint3), (5**10)%15) self.assertEqual(pow(bigint1, 10, 15), (5**10)%15) + +class TaskTestCase(EosTestCase): + @classmethod + def setUpClass(cls): + cls.db_connect_and_reset() + + def test_normal(self): + class TaskNormal(Task): + result = StringField() + def _run(self): + self.messages.append('Hello World') + self.result = 'Success' + + task = TaskNormal(run_strategy=DirectRunStrategy()) + task.save() + task.run() + + self.assertEqual(task.status, Task.Status.COMPLETE) + self.assertEqual(len(task.messages), 1) + self.assertEqual(task.messages[0], 'Hello World') + self.assertEqual(task.result, 'Success') + + def test_error(self): + class TaskError(Task): + def _run(self): + raise Exception('Test exception') + + task = TaskError(run_strategy=DirectRunStrategy()) + task.save() + task.run() + + self.assertEqual(task.status, Task.Status.FAILED) + self.assertEqual(len(task.messages), 1) + self.assertTrue('Test exception' in task.messages[0]) diff --git a/eos/psr/tests.py b/eos/psr/tests.py index f4c8d49..5b47db5 100644 --- a/eos/psr/tests.py +++ b/eos/psr/tests.py @@ -216,8 +216,7 @@ class MixnetTestCase(EosTestCase): class ElectionTestCase(EosTestCase): @classmethod def setUpClass(cls): - db_connect('test') - dbinfo.provider.reset_db() + cls.db_connect_and_reset() def do_task_assert(self, election, task, next_task): self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY)