From d8675ce52546243940a1c198d639ebdb342271fb Mon Sep 17 00:00:00 2001 From: nharris Date: Mon, 19 Sep 2022 10:42:23 -0600 Subject: [PATCH] - Actually fix the bot placeholder message (stupid return statement) (closes #40) - Fix UseEmbed config parsing - Add return types to functions for more readability - Simplify start-up message search --- modules/config_parser.py | 2 +- modules/discord_connector.py | 70 +++++++++++++++--------------- modules/tautulli_connector.py | 80 +++++++++++++++++------------------ modules/utils.py | 10 ++--- 4 files changed, 81 insertions(+), 81 deletions(-) diff --git a/modules/config_parser.py b/modules/config_parser.py index cc918a8..ef562bb 100644 --- a/modules/config_parser.py +++ b/modules/config_parser.py @@ -186,7 +186,7 @@ def _customization(self) -> ConfigSection: @property def use_embeds(self) -> bool: - value = self._get_value(key="UseEmbeds", default=False, env_name_override="TC_USE_EMBEDS") + value = self._customization._get_value(key="UseEmbeds", default=False, env_name_override="TC_USE_EMBEDS") return _extract_bool(value) diff --git a/modules/discord_connector.py b/modules/discord_connector.py index 605f0a9..0ed9e42 100644 --- a/modules/discord_connector.py +++ b/modules/discord_connector.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Union import discord from discord.ext import tasks @@ -65,15 +66,15 @@ async def add_emoji_number_reactions(message: discord.Message, count: int): await message.add_reaction(statics.emoji_numbers[i]) -async def send_starter_message(tautulli_connector, discord_channel): +async def send_starter_message(tautulli_connector, discord_channel: discord.TextChannel) -> discord.Message: if tautulli_connector.use_embeds: embed = discord.Embed(title="Welcome to Tauticord!") embed.add_field(name="Starting up...", value='This will be replaced once we get data.', inline=False) - await discord_channel.send(embed=embed) + return await discord_channel.send(content=None, embed=embed) else: - await discord_channel.send(content="Welcome to Tauticord!") + return await discord_channel.send(content="Welcome to Tauticord!") async def send_message(content: TautulliDataResponse, embed: bool = False, message: discord.Message = None, @@ -92,24 +93,24 @@ async def send_message(content: TautulliDataResponse, embed: bool = False, messa if message: # if message exists, use it to edit the message if embed: # let's send an embed if not content.embed: # oops, no embed to send - await message.edit(content="Placeholder", embed=None) # erase any existing content and embeds + await message.edit(content="Something went wrong.", embed=None) # erase any existing content and embeds else: await message.edit(content=None, embed=content.embed) # erase any existing content and embeds else: # let's send a normal message if not content.message: # oops, no message to send - await message.edit(content="Placeholder", embed=None) # erase any existing content and embeds + await message.edit(content="Something went wrong.", embed=None) # erase any existing content and embeds else: await message.edit(content=content.message, embed=None) # erase any existing content and embeds return message else: # otherwise, send a new message in the channel if embed: # let's send an embed if not content.embed: # oops, no embed to send - return await channel.send(content="Placeholder") + return await channel.send(content="Something went wrong.") else: return await channel.send(content=None, embed=content.embed) else: # let's send a normal message if not content.message: # oops, no message to send - return await channel.send(content="Placeholder") + return await channel.send(content="Something went wrong.") else: return await channel.send(content=content.message) @@ -129,7 +130,7 @@ def __init__(self, self.owner_id = owner_id self._refresh_time = refresh_time self.tautulli_channel_name = tautulli_channel_name - self.tautulli_channel = None + self.tautulli_channel: discord.TextChannel = None self.tautulli = tautulli_connector self.analytics = analytics self.use_embeds = use_embeds @@ -140,19 +141,19 @@ def __init__(self, def refresh_time(self) -> int: return max([5, self._refresh_time]) # minimum 5-second sleep time hard-coded, trust me, don't DDoS your server - async def on_ready(self): + async def on_ready(self) -> None: info('Connected to Discord.') self.update_libraries.start() await start_bot(discord_connector=self, analytics=self.analytics) - def connect(self): + def connect(self) -> None: info('Connecting to Discord...') self.client.run(self.token) - def is_me(self, message): + def is_me(self, message) -> bool: return message.author == self.client.user - async def edit_message(self, previous_message): + async def edit_message(self, previous_message) -> discord.Message: """ Collect new summary info, replace old message with new one :param previous_message: discord.Message to replace @@ -249,14 +250,16 @@ def check(reaction, user): await asyncio.sleep(self.refresh_time) return new_message - async def get_tautulli_channel(self): + async def get_tautulli_channel(self) -> None: info(f"Getting {self.tautulli_channel_name} channel") - self.tautulli_channel = await self.get_discord_channel_by_name(channel_name=self.tautulli_channel_name) + self.tautulli_channel: discord.TextChannel = \ + await self.get_discord_channel_by_name(channel_name=self.tautulli_channel_name) if not self.tautulli_channel: raise Exception(f"Could not load {self.tautulli_channel_name} channel. Exiting...") info(f"{self.tautulli_channel_name} channel collected.") - async def get_discord_channel_by_starting_name(self, starting_channel_name: str, channel_type: str = "text"): + async def get_discord_channel_by_starting_name(self, starting_channel_name: str, channel_type: str = "text") -> \ + Union[discord.VoiceChannel, discord.TextChannel]: for channel in self.client.get_all_channels(): if channel.name.startswith(starting_channel_name): return channel @@ -269,7 +272,8 @@ async def get_discord_channel_by_starting_name(self, starting_channel_name: str, except: raise Exception(f"Could not create channel {starting_channel_name}") - async def get_discord_channel_by_name(self, channel_name: str, channel_type: str = "text"): + async def get_discord_channel_by_name(self, channel_name: str, channel_type: str = "text") -> \ + Union[discord.VoiceChannel, discord.TextChannel]: for channel in self.client.get_all_channels(): if channel.name == channel_name: return channel @@ -283,7 +287,7 @@ async def get_discord_channel_by_name(self, channel_name: str, channel_type: str except: raise Exception(f"Could not create channel {channel_name}") - async def edit_library_voice_channel(self, channel_name: str, count: int): + async def edit_library_voice_channel(self, channel_name: str, count: int) -> None: info(f"Updating {channel_name} voice channel with new library size") channel = await self.get_discord_channel_by_starting_name(starting_channel_name=f"{channel_name}:", channel_type="voice") @@ -295,7 +299,7 @@ async def edit_library_voice_channel(self, channel_name: str, count: int): except Exception as voice_channel_edit_error: pass - async def edit_bandwidth_voice_channel(self, channel_name: str, size: int): + async def edit_bandwidth_voice_channel(self, channel_name: str, size: int) -> None: info(f"Updating {channel_name} voice channel with new bandwidth") channel = await self.get_discord_channel_by_starting_name(starting_channel_name=f"{channel_name}:", channel_type="voice") @@ -307,7 +311,7 @@ async def edit_bandwidth_voice_channel(self, channel_name: str, size: int): except Exception as voice_channel_edit_error: pass - async def edit_stream_count_voice_channel(self, channel_name: str, count: int): + async def edit_stream_count_voice_channel(self, channel_name: str, count: int) -> None: info(f"Updating {channel_name} voice channel with new stream count") channel = await self.get_discord_channel_by_starting_name(starting_channel_name=f"{channel_name}:", channel_type="voice") @@ -319,25 +323,21 @@ async def edit_stream_count_voice_channel(self, channel_name: str, count: int): except Exception as voice_channel_edit_error: pass - async def get_old_message_in_tautulli_channel(self): + async def get_old_message_in_tautulli_channel(self) -> discord.Message: """ Get the last message sent in the Tautulli channel, used to start the bot loop :return: discord.Message """ - last_bot_message_id = "" - while last_bot_message_id == "": - async for msg in self.tautulli_channel.history(limit=1): - print(msg) - if msg.author == self.client.user: - last_bot_message_id = msg.id - await msg.clear_reactions() - break - if last_bot_message_id == "": - info("Couldn't find old message, sending initial message...") - await send_starter_message(tautulli_connector=self.tautulli, discord_channel=self.tautulli_channel) - return await self.tautulli_channel.fetch_message(last_bot_message_id) - - async def update_voice_channels(self, activity): + # If the very last message in the channel is from Tauticord, use it + async for msg in self.tautulli_channel.history(limit=1): + if msg.author == self.client.user: + await msg.clear_reactions() + return msg + # If the very last message in the channel is not from Tauticord, make a new one. + info("Couldn't find old message, sending initial message...") + return await send_starter_message(tautulli_connector=self.tautulli, discord_channel=self.tautulli_channel) + + async def update_voice_channels(self, activity) -> None: if activity: if self.tautulli.voice_channel_settings.get('count', False): await self.edit_stream_count_voice_channel(channel_name="Current Streams", count=activity.stream_count) @@ -354,7 +354,7 @@ async def update_voice_channels(self, activity): size=activity.wan_bandwidth) @tasks.loop(hours=1.0) - async def update_libraries(self): + async def update_libraries(self) -> None: if self.tautulli.voice_channel_settings.get('stats', False): for library_name in self.tautulli.voice_channel_settings.get('libraries', []): size = self.tautulli.get_library_item_count(library_name=library_name) diff --git a/modules/tautulli_connector.py b/modules/tautulli_connector.py index 0462ea8..8e013f5 100644 --- a/modules/tautulli_connector.py +++ b/modules/tautulli_connector.py @@ -2,6 +2,7 @@ import discord import tautulli +from tautulli.models.activity import Session import modules.statics as statics from modules import utils @@ -16,7 +17,7 @@ def __init__(self, activity_data, time_settings: dict): self._time_settings = time_settings @property - def stream_count(self): + def stream_count(self) -> int: value = self._data.get('stream_count', 0) try: return int(value) @@ -24,7 +25,7 @@ def stream_count(self): return 0 @property - def transcode_count(self): + def transcode_count(self) -> int: value = self._data.get('stream_count_transcode', 0) try: return int(value) @@ -32,7 +33,7 @@ def transcode_count(self): return 0 @property - def total_bandwidth(self): + def total_bandwidth(self) -> Union[str, None]: value = self._data.get('total_bandwidth', 0) try: return utils.human_bitrate(float(value) * 1024) @@ -40,7 +41,7 @@ def total_bandwidth(self): return None @property - def lan_bandwidth(self): + def lan_bandwidth(self) -> Union[str, None]: value = self._data.get('lan_bandwidth', 0) try: return utils.human_bitrate(float(value) * 1024) @@ -48,7 +49,7 @@ def lan_bandwidth(self): return None @property - def wan_bandwidth(self): + def wan_bandwidth(self) -> Union[str, None]: total = self._data.get('total_bandwidth', 0) lan = self._data.get('lan_bandwidth', 0) value = total - lan @@ -57,10 +58,8 @@ def wan_bandwidth(self): except: return None - - @property - def message(self): + def message(self) -> str: overview_message = "" if self.stream_count > 0: overview_message += statics.sessions_message.format(stream_count=self.stream_count, @@ -77,7 +76,7 @@ def message(self): return overview_message @property - def sessions(self): + def sessions(self) -> List[Session]: return [Session(session_data=session_data, time_settings=self._time_settings) for session_data in self._data.get('sessions', [])] @@ -88,7 +87,7 @@ def __init__(self, session_data, time_settings: dict): self._time_settings = time_settings @property - def duration_milliseconds(self): + def duration_milliseconds(self) -> int: value = self._data.get('duration', 0) try: value = int(value) @@ -97,7 +96,7 @@ def duration_milliseconds(self): return int(value) @property - def location_milliseconds(self): + def location_milliseconds(self) -> int: value = self._data.get('view_offset', 0) try: value = int(value) @@ -106,19 +105,19 @@ def location_milliseconds(self): return int(value) @property - def progress_percentage(self): + def progress_percentage(self) -> int: if not self.duration_milliseconds: return 0 return int(self.location_milliseconds / self.duration_milliseconds) @property - def progress_marker(self): + def progress_marker(self) -> str: current_progress_min_sec = utils.milliseconds_to_minutes_seconds(milliseconds=self.location_milliseconds) total_min_sec = utils.milliseconds_to_minutes_seconds(milliseconds=self.duration_milliseconds) return f"{current_progress_min_sec}/{total_min_sec}" @property - def eta(self): + def eta(self) -> str: if not self.duration_milliseconds or not self.location_milliseconds: return "Unknown" milliseconds_remaining = self.duration_milliseconds - self.location_milliseconds @@ -129,7 +128,7 @@ def eta(self): return eta_string @property - def title(self): + def title(self) -> str: if self._data.get('live'): return f"{self._data.get('grandparent_title', '')} - {self._data['title']}" elif self._data['media_type'] == 'episode': @@ -138,7 +137,7 @@ def title(self): return self._data.get('full_title') @property - def status_icon(self): + def status_icon(self) -> str: """ Get icon for a stream state :return: emoji icon @@ -146,7 +145,7 @@ def status_icon(self): return statics.switcher.get(self._data['state'], "") @property - def type_icon(self): + def type_icon(self) -> str: if self._data['media_type'] in statics.media_type_icons: return statics.media_type_icons[self._data['media_type']] # thanks twilsonco @@ -157,23 +156,23 @@ def type_icon(self): return '🎁' @property - def id(self): + def id(self) -> str: return self._data['session_id'] @property - def username(self): + def username(self) -> str: return self._data['username'] @property - def product(self): + def product(self) -> str: return self._data['product'] @property - def player(self): + def player(self) -> str: return self._data['player'] @property - def quality_profile(self): + def quality_profile(self) -> str: return self._data['quality_profile'] @property @@ -186,26 +185,26 @@ def bandwidth(self) -> str: return utils.human_bitrate(float(value) * 1024) @property - def transcoding_stub(self): + def transcoding_stub(self) -> str: return ' (Transcode)' if self.stream_container_decision == 'transcode' else '' @property - def stream_container_decision(self): + def stream_container_decision(self) -> str: return self._data['stream_container_decision'] - def _session_title(self, session_number: int): + def _session_title(self, session_number: int) -> str: return statics.session_title_message.format(count=statics.emoji_numbers[session_number - 1], icon=self.status_icon, username=self.username, media_type_icon=self.type_icon, title=self.title) - def _session_player(self): + def _session_player(self) -> str: return statics.session_player_message.format(product=self.product, player=self.player) - def _session_details(self): + def _session_details(self) -> str: return statics.session_details_message.format(quality_profile=self.quality_profile, bandwidth=self.bandwidth, transcoding=self.transcoding_stub) - def _session_progress(self): + def _session_progress(self) -> str: return statics.session_progress_message.format(progress=self.progress_marker, eta=self.eta) @@ -215,26 +214,26 @@ def __init__(self, session: Session, session_number: int): self._session_number = session_number @property - def title(self): + def title(self) -> str: try: return self._session._session_title(session_number=self._session_number) except Exception as title_exception: return "Unknown" @property - def player(self): + def player(self) -> str: return self._session._session_player() @property - def details(self): + def details(self) -> str: return self._session._session_details() @property - def progress(self): + def progress(self) -> str: return self._session._session_progress() @property - def body(self): + def body(self) -> str: try: return f"{self.player}\n{self.details}\n{self.progress}" except Exception as body_exception: @@ -251,7 +250,7 @@ def __init__(self, overview_message: str, streams_info: List[TautulliStreamInfo] self.error = error_occurred @property - def embed(self): + def embed(self) -> discord.Embed: if len(self._streams) <= 0: return discord.Embed(title="No current activity") embed = discord.Embed(title=self._overview_message) @@ -262,13 +261,14 @@ def embed(self): return embed @property - def message(self): + def message(self) -> str: if len(self._streams) <= 0: return "No current activity." final_message = f"{self._overview_message}\n" for stream in self._streams: final_message += f"{stream.title}\n{stream.body}\n" final_message += f"\nTo terminate a stream, react with the stream number." + return final_message class TautulliConnector: @@ -291,7 +291,7 @@ def __init__(self, self.voice_channel_settings = voice_channel_settings self.time_settings = time_settings - def _error_and_analytics(self, error_message, function_name): + def _error_and_analytics(self, error_message, function_name) -> None: error(error_message) self.analytics.event(event_category="Error", event_action=function_name, random_uuid_if_needed=True) @@ -327,7 +327,7 @@ def refresh_data(self) -> Tuple[TautulliDataResponse, int, Union[Activity, None] self._error_and_analytics(error_message=e, function_name='refresh_data (KeyError)') return TautulliDataResponse(overview_message="**Connection lost.**", error_occurred=True), 0, None - def stop_stream(self, stream_number): + def stop_stream(self, stream_number) -> str: """ Stop a Plex stream :param stream_number: stream number used to react to Discord message (ex. 1, 2, 3) @@ -345,21 +345,21 @@ def stop_stream(self, stream_number): self._error_and_analytics(error_message=e, function_name='stop_stream') return "Something went wrong." - def get_library_id(self, library_name: str): + def get_library_id(self, library_name: str) -> Union[str, None]: for library in self.api.library_names: if library.get('section_name') == library_name: return library.get('section_id') error(f"Could not get ID for library {library_name}") return None - def get_library_info(self, library_name: str): + def get_library_info(self, library_name: str) -> Union[dict, None]: info(f"Collecting stats about library {library_name}") library_id = self.get_library_id(library_name=library_name) if not library_id: return None return self.api.get_library(section_id=library_id) - def get_library_item_count(self, library_name: str): + def get_library_item_count(self, library_name: str) -> int: library_info = self.get_library_info(library_name=library_name) if not library_info: return 0 diff --git a/modules/utils.py b/modules/utils.py index 09ec1b2..044122a 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -3,20 +3,20 @@ from pytz import timezone -def make_plural(word, count: int, suffix_override: str = 's'): +def make_plural(word, count: int, suffix_override: str = 's') -> str: if count > 1: return f"{word}{suffix_override}" return word -def _human_bitrate(number, denominator: int = 1, letter: str = "", d: int = 1): +def _human_bitrate(number, denominator: int = 1, letter: str = "", d: int = 1) -> str: if d <= 0: return f'{int(number / denominator):d} {letter}bps' else: return f'{float(number / denominator):.{d}f} {letter}bps' -def human_bitrate(_bytes, d: int = 1): +def human_bitrate(_bytes, d: int = 1) -> str: # Return the given bitrate as a human friendly bps, Kbps, Mbps, Gbps, or Tbps string KB = float(1024) @@ -44,7 +44,7 @@ def human_bitrate(_bytes, d: int = 1): return _human_bitrate(_bytes, denominator=denominator, letter=letter, d=d) -def milliseconds_to_minutes_seconds(milliseconds: int): +def milliseconds_to_minutes_seconds(milliseconds: int) -> str: seconds = int(milliseconds / 1000) minutes = int(seconds / 60) if minutes < 10: @@ -55,7 +55,7 @@ def milliseconds_to_minutes_seconds(milliseconds: int): return f"{minutes}:{seconds}" -def now_plus_milliseconds(milliseconds: int, timezone_code: str = None): +def now_plus_milliseconds(milliseconds: int, timezone_code: str = None) -> datetime: if timezone_code: now = datetime.now(timezone(timezone_code)) # will raise exception if invalid timezone_code else: