Implement deletion

This commit is contained in:
Yingtong Li 2018-01-04 16:40:32 +08:00
parent e28a003bae
commit b62933629b
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
4 changed files with 17 additions and 0 deletions

View File

@ -36,6 +36,9 @@ class DBProvider:
def update_by_id(self, collection, _id, value):
raise Exception('Not implemented')
def delete_by_id(self, collection, _id):
raise Exception('Not implemented')
def reset_db(self):
raise Exception('Not implemented')
@ -55,6 +58,9 @@ class DummyProvider(DBProvider):
def update_by_id(self, collection, _id, value):
pass
def delete_by_id(self, collection, _id):
pass
def reset_db(self):
pass

View File

@ -42,6 +42,9 @@ class MongoDBProvider(eos.core.db.DBProvider):
def update_by_id(self, collection, _id, value):
self.db[collection].replace_one({'_id': _id}, value, upsert=True)
def delete_by_id(self, collection, _id):
self.db[collection].delete_one({'_id': _id})
def reset_db(self):
self.client.drop_database(self.db_name)

View File

@ -58,6 +58,11 @@ class PostgreSQLDBProvider(eos.core.db.DBProvider):
self.cur.execute(SQL('INSERT INTO {} (_id, data) VALUES (%s, %s) ON CONFLICT (_id) DO UPDATE SET data = excluded.data').format(Identifier(table)), (_id, psycopg2.extras.Json(value)))
self.conn.commit()
def delete_by_id(self, table, _id):
self.create_table(table)
self.cur.execute(SQL('DELETE FROM {} WHERE _id = %s').format(Identifier(table)), (_id))
self.conn.commit()
def reset_db(self):
self.cur.execute('DROP SCHEMA public CASCADE; CREATE SCHEMA public; GRANT ALL ON SCHEMA public TO postgres; GRANT ALL ON SCHEMA public TO public')
self.conn.commit()

View File

@ -481,6 +481,9 @@ class TopLevelObject(DocumentObject, metaclass=TopLevelObjectType):
#res = dbinfo.db[self._db_name].replace_one({'_id': self._fields['_id'].serialise(self._id)}, EosObject.serialise_and_wrap(self), upsert=True)
dbinfo.provider.update_by_id(self._db_name, self._fields['_id'].serialise(self._id), EosObject.serialise_and_wrap(self))
def delete(self):
dbinfo.provider.delete_by_id(self._db_name, self._fields['_id'].serialise(self._id))
@classmethod
def get_all(cls):
return [EosObject.deserialise_and_unwrap(x) for x in dbinfo.provider.get_all(cls._db_name)]