Improve alert handling, centralize aiohttp.ClientSession, and fix minor bugs

This commit is contained in:
Emily Doherty 2025-11-06 00:22:21 -08:00
parent 7e387b6cd7
commit faab0d5f7e
3 changed files with 24 additions and 10 deletions

View File

@ -11,7 +11,7 @@ from token_bot.persistant_database import database as pdb
class AlertsController:
def __init__(self, session: aiohttp.ClientSession):
self._pdb: pdb.Database = pdb.Database(session)
self.table = aiodynamo.client.Table = self._pdb.client.table(
self.table: aiodynamo.client.Table = self._pdb.client.table(
os.getenv("ALERTS_TABLE")
)

View File

@ -93,6 +93,7 @@ class Alert:
return (
self.alert_type == other.alert_type
and self.flavor == other.flavor
and self.region == other.region
and self.price == other.price
)

View File

@ -60,6 +60,7 @@ class Tracker(Extension):
self._alerts: AlertsController | None = None
self._tdb: tdb.Database | None = None
self._history_manager: HistoryManager | None = None
self._session: aiohttp.ClientSession | None = None
###################################
# Task Functions #
@ -87,12 +88,10 @@ class Tracker(Extension):
discord_user = await self.bot.fetch_user(user.user_id)
alerts_by_flavor = await gather_alerts_by_flavor(users_alerts[user])
alert_tally = 0
alert_word = "alert"
if alert_tally > 2:
alert_word += 's'
for flavor in alerts_by_flavor:
for _ in alerts_by_flavor[flavor]:
alert_tally += 1
alert_word = "alert" if alert_tally == 1 else "alerts"
embeds = [
Embed(
title="GoblinBot Tracker Alert Triggered",
@ -131,9 +130,11 @@ class Tracker(Extension):
@listen(Startup)
async def on_start(self):
self.bot.logger.log(logging.INFO, "TokenBot Tracker: Initializing")
self._users = UsersController(aiohttp.ClientSession())
self._alerts = AlertsController(aiohttp.ClientSession())
self._tdb = tdb.Database(aiohttp.ClientSession())
# Create a single shared ClientSession for all components
self._session = aiohttp.ClientSession()
self._users = UsersController(self._session)
self._alerts = AlertsController(self._session)
self._tdb = tdb.Database(self._session)
self._history_manager = HistoryManager(self._tdb)
self.bot.logger.log(logging.INFO, "TokenBot Tracker: Initialized")
self.bot.logger.log(logging.INFO, "TokenBot Tracker: Loading Historical Data")
@ -144,6 +145,12 @@ class Tracker(Extension):
self.bot.logger.log(logging.INFO, "TokenBot Tracker: Started")
self.update_data.start()
def extension_unload(self):
"""Clean up resources when the extension is unloaded"""
if self._session and not self._session.closed:
asyncio.create_task(self._session.close())
self.bot.logger.log(logging.INFO, "TokenBot Tracker: ClientSession closed")
@slash_command(
name="register",
description="Register with a new GoblinBot Region for alerts on token price changes.",
@ -285,11 +292,11 @@ class Tracker(Extension):
await ctx.send(f"Selected Flavor: {ctx.values[0]}", ephemeral=True)
@component_callback("high_alert_menu")
async def alert_menu(self, ctx: ComponentContext):
async def high_alert_menu(self, ctx: ComponentContext):
await ctx.send(f"Selected Alert: {ctx.values[0]}", ephemeral=True)
@component_callback("low_alert_menu")
async def alert_menu(self, ctx: ComponentContext):
async def low_alert_menu(self, ctx: ComponentContext):
await ctx.send(f"Selected Alert: {ctx.values[0]}", ephemeral=True)
@component_callback("remove_alert_menu")
@ -327,7 +334,7 @@ class Tracker(Extension):
async def get_current_token(self, ctx: SlashContext, flavor: Flavor) -> str:
user: User = await self._users.get(ctx.user.id)
if user.region.name is None:
if user.region is None:
return (
f"Please register with a region before attempting to list alerts using\n"
"```/register```"
@ -406,6 +413,8 @@ class Tracker(Extension):
)
raise TimeoutError
else:
# Acknowledge the component interaction to avoid 404 Unknown Interaction
await region_component.ctx.defer(edit_origin=True, suppress_error=True)
region_menu.disabled = True
region = Region(region_component.ctx.values[0].lower())
user = User(ctx.user.id, region, subscribed_alerts=[])
@ -432,6 +441,8 @@ class Tracker(Extension):
)
raise TimeoutError
else:
# Acknowledge the component interaction to avoid 404 Unknown Interaction
await flavor_component.ctx.defer(edit_origin=True, suppress_error=True)
flavor = Flavor[flavor_component.ctx.values[0].upper()]
flavor_menu.disabled = True
await flavor_message.edit(context=ctx, components=flavor_menu)
@ -454,6 +465,8 @@ class Tracker(Extension):
)
raise TimeoutError
else:
# Acknowledge the component interaction to avoid 404 Unknown Interaction
await alert_type_component.ctx.defer(edit_origin=True, suppress_error=True)
alert_type = AlertCategory.from_str(alert_type_component.ctx.custom_id)
for button in alert_type_button[0].components:
button.disabled = True