diff --git a/token_bot/controller/alerts.py b/token_bot/controller/alerts.py index 2865c9b..b01826b 100644 --- a/token_bot/controller/alerts.py +++ b/token_bot/controller/alerts.py @@ -11,7 +11,7 @@ from token_bot.persistant_database import database as pdb class AlertsController: def __init__(self, session: aiohttp.ClientSession): self._pdb: pdb.Database = pdb.Database(session) - self.table = aiodynamo.client.Table = self._pdb.client.table( + self.table: aiodynamo.client.Table = self._pdb.client.table( os.getenv("ALERTS_TABLE") ) diff --git a/token_bot/persistant_database/alert_schema.py b/token_bot/persistant_database/alert_schema.py index 4890e4d..f1a3047 100644 --- a/token_bot/persistant_database/alert_schema.py +++ b/token_bot/persistant_database/alert_schema.py @@ -93,6 +93,7 @@ class Alert: return ( self.alert_type == other.alert_type and self.flavor == other.flavor + and self.region == other.region and self.price == other.price ) diff --git a/token_bot/tracker.py b/token_bot/tracker.py index f4db1a7..50275a8 100644 --- a/token_bot/tracker.py +++ b/token_bot/tracker.py @@ -60,6 +60,7 @@ class Tracker(Extension): self._alerts: AlertsController | None = None self._tdb: tdb.Database | None = None self._history_manager: HistoryManager | None = None + self._session: aiohttp.ClientSession | None = None ################################### # Task Functions # @@ -87,12 +88,10 @@ class Tracker(Extension): discord_user = await self.bot.fetch_user(user.user_id) alerts_by_flavor = await gather_alerts_by_flavor(users_alerts[user]) alert_tally = 0 - alert_word = "alert" - if alert_tally > 2: - alert_word += 's' for flavor in alerts_by_flavor: for _ in alerts_by_flavor[flavor]: alert_tally += 1 + alert_word = "alert" if alert_tally == 1 else "alerts" embeds = [ Embed( title="GoblinBot Tracker Alert Triggered", @@ -131,9 +130,11 @@ class Tracker(Extension): @listen(Startup) async def on_start(self): self.bot.logger.log(logging.INFO, "TokenBot Tracker: Initializing") - self._users = UsersController(aiohttp.ClientSession()) - self._alerts = AlertsController(aiohttp.ClientSession()) - self._tdb = tdb.Database(aiohttp.ClientSession()) + # Create a single shared ClientSession for all components + self._session = aiohttp.ClientSession() + self._users = UsersController(self._session) + self._alerts = AlertsController(self._session) + self._tdb = tdb.Database(self._session) self._history_manager = HistoryManager(self._tdb) self.bot.logger.log(logging.INFO, "TokenBot Tracker: Initialized") self.bot.logger.log(logging.INFO, "TokenBot Tracker: Loading Historical Data") @@ -144,6 +145,12 @@ class Tracker(Extension): self.bot.logger.log(logging.INFO, "TokenBot Tracker: Started") self.update_data.start() + def extension_unload(self): + """Clean up resources when the extension is unloaded""" + if self._session and not self._session.closed: + asyncio.create_task(self._session.close()) + self.bot.logger.log(logging.INFO, "TokenBot Tracker: ClientSession closed") + @slash_command( name="register", description="Register with a new GoblinBot Region for alerts on token price changes.", @@ -285,11 +292,11 @@ class Tracker(Extension): await ctx.send(f"Selected Flavor: {ctx.values[0]}", ephemeral=True) @component_callback("high_alert_menu") - async def alert_menu(self, ctx: ComponentContext): + async def high_alert_menu(self, ctx: ComponentContext): await ctx.send(f"Selected Alert: {ctx.values[0]}", ephemeral=True) @component_callback("low_alert_menu") - async def alert_menu(self, ctx: ComponentContext): + async def low_alert_menu(self, ctx: ComponentContext): await ctx.send(f"Selected Alert: {ctx.values[0]}", ephemeral=True) @component_callback("remove_alert_menu") @@ -327,7 +334,7 @@ class Tracker(Extension): async def get_current_token(self, ctx: SlashContext, flavor: Flavor) -> str: user: User = await self._users.get(ctx.user.id) - if user.region.name is None: + if user.region is None: return ( f"Please register with a region before attempting to list alerts using\n" "```/register```" @@ -406,6 +413,8 @@ class Tracker(Extension): ) raise TimeoutError else: + # Acknowledge the component interaction to avoid 404 Unknown Interaction + await region_component.ctx.defer(edit_origin=True, suppress_error=True) region_menu.disabled = True region = Region(region_component.ctx.values[0].lower()) user = User(ctx.user.id, region, subscribed_alerts=[]) @@ -432,6 +441,8 @@ class Tracker(Extension): ) raise TimeoutError else: + # Acknowledge the component interaction to avoid 404 Unknown Interaction + await flavor_component.ctx.defer(edit_origin=True, suppress_error=True) flavor = Flavor[flavor_component.ctx.values[0].upper()] flavor_menu.disabled = True await flavor_message.edit(context=ctx, components=flavor_menu) @@ -454,6 +465,8 @@ class Tracker(Extension): ) raise TimeoutError else: + # Acknowledge the component interaction to avoid 404 Unknown Interaction + await alert_type_component.ctx.defer(edit_origin=True, suppress_error=True) alert_type = AlertCategory.from_str(alert_type_component.ctx.custom_id) for button in alert_type_button[0].components: button.disabled = True