diff --git a/eos/base/tests.py b/eos/base/tests.py
index 45e2dff..be6d865 100644
--- a/eos/base/tests.py
+++ b/eos/base/tests.py
@@ -23,8 +23,7 @@ from eos.core.objects import *
class ElectionTestCase(EosTestCase):
@classmethod
def setUpClass(cls):
- db_connect('test')
- dbinfo.provider.reset_db()
+ cls.db_connect_and_reset()
def do_task_assert(self, election, task, next_task):
self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY)
diff --git a/eos/core/db/__init__.py b/eos/core/db/__init__.py
index 61c5249..58272ba 100644
--- a/eos/core/db/__init__.py
+++ b/eos/core/db/__init__.py
@@ -20,3 +20,36 @@ class DBProvider:
def __init__(self, db_name, db_uri):
self.db_name = db_name
self.db_uri = db_uri
+
+ def connect(self):
+ raise Exception('Not implemented')
+
+ def get_all(self, collection):
+ raise Exception('Not implemented')
+
+ def get_by_id(self, collection, _id):
+ raise Exception('Not implemented')
+
+ def update_by_id(self, collection, _id, value):
+ raise Exception('Not implemented')
+
+ def reset_db(self):
+ raise Exception('Not implemented')
+
+class DummyProvider(DBProvider):
+ def connect(self):
+ pass
+
+ def get_all(self, collection):
+ pass
+
+ def get_by_id(self, collection, _id):
+ pass
+
+ def update_by_id(self, collection, _id, value):
+ pass
+
+ def reset_db(self):
+ pass
+
+db_providers['dummy'] = DummyProvider
diff --git a/eos/core/objects/__init__.py b/eos/core/objects/__init__.py
index 053c0b9..8c7ed9a 100644
--- a/eos/core/objects/__init__.py
+++ b/eos/core/objects/__init__.py
@@ -25,9 +25,10 @@ except:
# Libraries
# =========
+import eos.core.db
+
if is_python:
__pragma__('skip')
- import eos.core.db
import eos.core.db.mongodb
import eos.core.db.postgresql
@@ -54,7 +55,7 @@ else:
class DBInfo:
def __init__(self):
- self.provider = None
+ self.provider = eos.core.db.DummyProvider(None, None)
dbinfo = DBInfo()
diff --git a/eos/core/tasks/__init__.py b/eos/core/tasks/__init__.py
new file mode 100644
index 0000000..a8c5522
--- /dev/null
+++ b/eos/core/tasks/__init__.py
@@ -0,0 +1,63 @@
+# Eos - Verifiable elections
+# Copyright © 2017 RunasSudo (Yingtong Li)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from eos.core.objects import *
+
+class Task(TopLevelObject):
+ class Status:
+ UNKNOWN = 0
+
+ READY = 20
+ PROCESSING = 30
+ COMPLETE = 50
+
+ FAILED = -10
+ TIMEOUT = -20
+
+ _id = UUIDField()
+ status = IntField(default=0)
+ run_strategy = EmbeddedObjectField()
+ messages = ListField(StringField())
+
+ def run(self):
+ self.run_strategy.run(self)
+
+ def _run(self):
+ pass
+
+class RunStrategy(DocumentObject):
+ def run(self, task):
+ raise Exception('Not implemented')
+
+class DirectRunStrategy(RunStrategy):
+ def run(self, task):
+ task.status = Task.Status.PROCESSING
+ task.save()
+
+ try:
+ task._run()
+ task.status = Task.Status.COMPLETE
+ task.save()
+ except Exception as e:
+ task.status = Task.Status.FAILED
+ if is_python:
+ #__pragma__('skip')
+ import traceback
+ #__pragma__('noskip')
+ task.messages.append(traceback.format_exc())
+ else:
+ task.messages.append(repr(e))
+ task.save()
diff --git a/eos/core/tasks/direct.py b/eos/core/tasks/direct.py
new file mode 100644
index 0000000..f5138fa
--- /dev/null
+++ b/eos/core/tasks/direct.py
@@ -0,0 +1,19 @@
+# Eos - Verifiable elections
+# Copyright © 2017 RunasSudo (Yingtong Li)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from eos.core.tasks import *
+
+
diff --git a/eos/core/tests.py b/eos/core/tests.py
index 93b623d..a8ed751 100644
--- a/eos/core/tests.py
+++ b/eos/core/tests.py
@@ -17,6 +17,7 @@
from eos.core.bigint import *
from eos.core.objects import *
from eos.core.hashing import *
+from eos.core.tasks import *
# Common library things
# ===================
@@ -26,6 +27,11 @@ class EosTestCase:
def setUpClass(cls):
pass
+ @classmethod
+ def db_connect_and_reset(cls):
+ db_connect('test')
+ dbinfo.provider.reset_db()
+
def assertTrue(self, a):
if is_python:
self.impl.assertTrue(a)
@@ -120,3 +126,37 @@ class BigIntTestCase(EosTestCase):
self.assertEqual(pow(bigint1, bigint2), 5**10)
self.assertEqual(pow(bigint1, bigint2, bigint3), (5**10)%15)
self.assertEqual(pow(bigint1, 10, 15), (5**10)%15)
+
+class TaskTestCase(EosTestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.db_connect_and_reset()
+
+ def test_normal(self):
+ class TaskNormal(Task):
+ result = StringField()
+ def _run(self):
+ self.messages.append('Hello World')
+ self.result = 'Success'
+
+ task = TaskNormal(run_strategy=DirectRunStrategy())
+ task.save()
+ task.run()
+
+ self.assertEqual(task.status, Task.Status.COMPLETE)
+ self.assertEqual(len(task.messages), 1)
+ self.assertEqual(task.messages[0], 'Hello World')
+ self.assertEqual(task.result, 'Success')
+
+ def test_error(self):
+ class TaskError(Task):
+ def _run(self):
+ raise Exception('Test exception')
+
+ task = TaskError(run_strategy=DirectRunStrategy())
+ task.save()
+ task.run()
+
+ self.assertEqual(task.status, Task.Status.FAILED)
+ self.assertEqual(len(task.messages), 1)
+ self.assertTrue('Test exception' in task.messages[0])
diff --git a/eos/psr/tests.py b/eos/psr/tests.py
index f4c8d49..5b47db5 100644
--- a/eos/psr/tests.py
+++ b/eos/psr/tests.py
@@ -216,8 +216,7 @@ class MixnetTestCase(EosTestCase):
class ElectionTestCase(EosTestCase):
@classmethod
def setUpClass(cls):
- db_connect('test')
- dbinfo.provider.reset_db()
+ cls.db_connect_and_reset()
def do_task_assert(self, election, task, next_task):
self.assertEqual(election.workflow.get_task(task).status, WorkflowTask.Status.READY)