diff --git a/token_bot/controller/alerts.py b/token_bot/controller/alerts.py index f032351..1626e55 100644 --- a/token_bot/controller/alerts.py +++ b/token_bot/controller/alerts.py @@ -25,16 +25,19 @@ class AlertsController: return Alert.from_str(alert) return alert - async def add_user(self, user: int | User, alert: str | Alert) -> None: - user = self._user_to_obj(user) + async def get_users(self, alert: str | Alert) -> List[User]: alert = self._alert_to_obj(alert) + await alert.get(self.table, consistent=True) + return alert.users + + async def add_user(self, alert: str | Alert, user: int | User): + alert = self._alert_to_obj(alert) + user = self._user_to_obj(user) await alert.add_user(self.table, user) - async def delete_user(self, user: int | User, alert: str | Alert): - user = self._user_to_obj(user) + async def remove_user(self, alert: str | Alert, user: int | User): alert = self._alert_to_obj(alert) + user = self._user_to_obj(user) await alert.remove_user(self.table, user) - async def get_users(self, alert: str | Alert, consistent: bool = False) -> List[User]: - alert = self._alert_to_obj(alert) - return await alert.get_users(self.table, consistent=consistent ) + diff --git a/token_bot/controller/users.py b/token_bot/controller/users.py index 35733e2..4057ced 100644 --- a/token_bot/controller/users.py +++ b/token_bot/controller/users.py @@ -60,3 +60,11 @@ class UsersController: await user.get(self.table) user.subscribed_alerts.append(alert) await user.put(self.table) + + async def remove_alert(self, user: int | User, alert: str | Alert) -> None: + user = self._user_to_obj(user) + alert = self._alert_to_obj(alert) + await user.get(self.table) + user.subscribed_alerts.remove(alert) + await user.put(self.table) + diff --git a/token_bot/persistant_database/alert_schema.py b/token_bot/persistant_database/alert_schema.py index 9487df1..adb3de9 100644 --- a/token_bot/persistant_database/alert_schema.py +++ b/token_bot/persistant_database/alert_schema.py @@ -19,7 +19,7 @@ class Alert: self.region: Region = region self.price: int = price self._loaded: bool = False - self._users: List[pdb.User] = [] + self.users: List[pdb.User] = [] @classmethod def from_item(cls, primary_key: int, sort_key: str, users: List[int]) -> 'Alert': @@ -82,18 +82,15 @@ class Alert: await self.get(table, consistent=consistent) async def _append_user(self, table: Table, user: pdb.User) -> None: - self._users.append(user) + self.users.append(user) await self.put(table) async def _remove_user(self, table: Table, user: pdb.User) -> None: - update_expression = F("users").delete({user.user_id}) - await table.update_item( - key=self.key, - update_expression=update_expression - ) + self.users.remove(user) + await self.put(table) async def put(self, table: Table) -> None: - user_ids = [str(user.user_id) for user in self._users] + user_ids = [str(user.user_id) for user in self.users] await table.put_item( item={ self.primary_key_name: self.primary_key, @@ -110,23 +107,23 @@ class Alert: ) except ItemNotFound: return False - self._users = [pdb.User(int(user_id)) for user_id in response['users']] + self.users = [pdb.User(int(user_id)) for user_id in response['users']] self._loaded = True return True async def get_users(self, table: Table, consistent: bool = False) -> List[pdb.User]: await self._lazy_load(table, consistent=consistent) - return self._users + return self.users async def add_user(self, table: Table, user: pdb.User, consistent: bool = False) -> None: await self._lazy_load(table, consistent=consistent) - if user not in self._users: + if user not in self.users: await self._append_user(table=table, user=user) async def remove_user(self, table: Table, user: pdb.User, consistent: bool = True) -> None: await self._lazy_load(table, consistent=consistent) - if user in self._users: + if user in self.users: await self._remove_user(table=table, user=user) diff --git a/token_bot/persistant_database/user_schema.py b/token_bot/persistant_database/user_schema.py index 70b51b5..80399ee 100644 --- a/token_bot/persistant_database/user_schema.py +++ b/token_bot/persistant_database/user_schema.py @@ -59,9 +59,6 @@ class User: async def delete(self, table: Table) -> None: if not self._loaded: await self._lazy_load(table, consistent=True) - if self.subscribed_alerts: - for alert in self.subscribed_alerts: - await alert.remove_user(table, self) await table.delete_item( key={self.primary_key_name: self.primary_key}, ) @@ -80,3 +77,15 @@ class User: self.subscribed_alerts.append(pdb.Alert.from_str(string_trinity)) self.region = Region(response['region']) return True + + async def add_alert(self, table: Table, alert: 'pdb.Alert', consistent: bool = False) -> None: + await self._lazy_load(table, consistent=consistent) + if alert not in self.subscribed_alerts: + self.subscribed_alerts.append(alert) + await self.put(table) + + async def remove_alert(self, table: Table, alert: 'pdb.Alert', consistent: bool = True) -> None: + await self._lazy_load(table, consistent=consistent) + if alert in self.subscribed_alerts: + self.subscribed_alerts.remove(alert) + await self.put(table)