From ab362b97cd8f2b3b281afa0ef2686d2686d3459f Mon Sep 17 00:00:00 2001 From: github-actions <${GITHUB_ACTOR}@users.noreply.github.com> Date: Sat, 14 Jan 2023 17:01:24 +0000 Subject: [PATCH] Format Python code with psf/black push --- cogs/commands.py | 44 ++++---- cogs/image_service_cog.py | 9 +- cogs/moderations_service_cog.py | 8 +- cogs/text_service_cog.py | 54 +++++---- gpt3discord.py | 1 - services/image_service.py | 23 ++-- services/text_service.py | 189 ++++++++++++++++++++++---------- 7 files changed, 203 insertions(+), 125 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index 6815f277..ec039095 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -8,7 +8,6 @@ ALLOWED_GUILDS = EnvService.get_allowed_guilds() - class Commands(discord.Cog, name="Commands"): def __init__( self, @@ -97,7 +96,6 @@ async def settings( async def local_size(self, ctx: discord.ApplicationContext): await self.image_draw_cog.local_size_command(ctx) - @add_to_group("system") @discord.slash_command( name="clear-local", @@ -108,7 +106,6 @@ async def local_size(self, ctx: discord.ApplicationContext): async def clear_local(self, ctx: discord.ApplicationContext): await self.image_draw_cog.clear_local_command(ctx) - @add_to_group("system") @discord.slash_command( name="usage", @@ -119,7 +116,6 @@ async def clear_local(self, ctx: discord.ApplicationContext): async def usage(self, ctx: discord.ApplicationContext): await self.converser_cog.usage_command(ctx) - @add_to_group("system") @discord.slash_command( name="set-usage", @@ -134,7 +130,6 @@ async def usage(self, ctx: discord.ApplicationContext): async def set_usage(self, ctx: discord.ApplicationContext, usage_amount: float): await self.converser_cog.set_usage_command(ctx, usage_amount) - @add_to_group("system") @discord.slash_command( name="delete-conversation-threads", @@ -163,7 +158,6 @@ async def delete_all_conversation_threads(self, ctx: discord.ApplicationContext) async def moderations_test(self, ctx: discord.ApplicationContext, prompt: str): await self.moderations_cog.moderations_test_command(ctx, prompt) - @add_to_group("mod") @discord.slash_command( name="set", @@ -241,8 +235,9 @@ async def ask( frequency_penalty: float, presence_penalty: float, ): - await self.converser_cog.ask_command(ctx, prompt, temperature, top_p, frequency_penalty, presence_penalty) - + await self.converser_cog.ask_command( + ctx, prompt, temperature, top_p, frequency_penalty, presence_penalty + ) @add_to_group("gpt") @discord.slash_command( @@ -251,10 +246,15 @@ async def ask( guild_ids=ALLOWED_GUILDS, ) @discord.option( - name="instruction", description="How you want GPT3 to edit the text", required=True + name="instruction", + description="How you want GPT3 to edit the text", + required=True, ) @discord.option( - name="input", description="The text you want to edit, can be empty", required=False, default="" + name="input", + description="The text you want to edit, can be empty", + required=False, + default="", ) @discord.option( name="temperature", @@ -273,10 +273,7 @@ async def ask( max_value=1, ) @discord.option( - name="codex", - description="Enable codex version", - required=False, - default=False + name="codex", description="Enable codex version", required=False, default=False ) @discord.guild_only() async def edit( @@ -288,8 +285,9 @@ async def edit( top_p: float, codex: bool, ): - await self.converser_cog.edit_command(ctx, instruction, input, temperature, top_p, codex) - + await self.converser_cog.edit_command( + ctx, instruction, input, temperature, top_p, codex + ) @add_to_group("gpt") @discord.slash_command( @@ -329,8 +327,9 @@ async def converse( private: bool, minimal: bool, ): - await self.converser_cog.converse_command(ctx, opener, opener_file, private, minimal) - + await self.converser_cog.converse_command( + ctx, opener, opener_file, private, minimal + ) @add_to_group("gpt") @discord.slash_command( @@ -356,7 +355,6 @@ async def end(self, ctx: discord.ApplicationContext): async def draw(self, ctx: discord.ApplicationContext, prompt: str): await self.image_draw_cog.draw_command(ctx, prompt) - @add_to_group("dalle") @discord.slash_command( name="optimize", @@ -375,15 +373,14 @@ async def optimize(self, ctx: discord.ApplicationContext, prompt: str): """ @discord.slash_command( - name="private-test", - description="Private thread for testing. Only visible to you and server admins.", - guild_ids=ALLOWED_GUILDS, + name="private-test", + description="Private thread for testing. Only visible to you and server admins.", + guild_ids=ALLOWED_GUILDS, ) @discord.guild_only() async def private_test(self, ctx: discord.ApplicationContext): await self.converser_cog.private_test_command(ctx) - @discord.slash_command( name="help", description="Get help for GPT3Discord", guild_ids=ALLOWED_GUILDS ) @@ -391,7 +388,6 @@ async def private_test(self, ctx: discord.ApplicationContext): async def help(self, ctx: discord.ApplicationContext): await self.converser_cog.help_command(ctx) - @discord.slash_command( name="setup", description="Setup your API key for use with GPT3Discord", diff --git a/cogs/image_service_cog.py b/cogs/image_service_cog.py index d14ca9fe..8698e028 100644 --- a/cogs/image_service_cog.py +++ b/cogs/image_service_cog.py @@ -40,11 +40,12 @@ def __init__( print("Draw service initialized") self.redo_users = {} - async def draw_command(self, ctx: discord.ApplicationContext, prompt: str): user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await TextService.get_user_api_key(ctx.user.id, ctx, USER_KEY_DB) + user_api_key = await TextService.get_user_api_key( + ctx.user.id, ctx, USER_KEY_DB + ) if not user_api_key: return @@ -58,8 +59,7 @@ async def draw_command(self, ctx: discord.ApplicationContext, prompt: str): try: asyncio.ensure_future( ImageService.encapsulated_send( - self, - user.id, prompt, ctx, custom_api_key=user_api_key + self, user.id, prompt, ctx, custom_api_key=user_api_key ) ) @@ -98,4 +98,3 @@ async def clear_local_command(self, ctx): print(e) await ctx.respond("Local images cleared.") - diff --git a/cogs/moderations_service_cog.py b/cogs/moderations_service_cog.py index f44d48fc..a65f074c 100644 --- a/cogs/moderations_service_cog.py +++ b/cogs/moderations_service_cog.py @@ -14,6 +14,7 @@ print("Failed to retrieve the General and Moderations DB") raise e + class ModerationsService(discord.Cog, name="ModerationsService"): def __init__( self, @@ -32,6 +33,7 @@ def __init__( self.moderation_enabled_guilds = [] self.moderation_tasks = {} self.moderations_launched = [] + @discord.Cog.listener() async def on_ready(self): # Check moderation service for each guild @@ -81,6 +83,7 @@ async def check_and_launch_moderations(self, guild_id, alert_channel_override=No return moderations_channel return None + async def moderations_command( self, ctx: discord.ApplicationContext, status: str, alert_channel_id: str ): @@ -118,9 +121,10 @@ async def moderations_command( Moderation.moderations_launched.remove(ctx.guild_id) await ctx.respond("Moderations service disabled") - async def moderations_test_command(self, ctx: discord.ApplicationContext, prompt: str): + async def moderations_test_command( + self, ctx: discord.ApplicationContext, prompt: str + ): await ctx.defer() response = await self.model.send_moderations_request(prompt) await ctx.respond(response["results"][0]["category_scores"]) await ctx.send_followup(response["results"][0]["flagged"]) - diff --git a/cogs/text_service_cog.py b/cogs/text_service_cog.py index a32beae7..907fee39 100644 --- a/cogs/text_service_cog.py +++ b/cogs/text_service_cog.py @@ -112,7 +112,6 @@ def __init__( self.conversation_threads = {} self.summarize = self.model.summarize_conversations - # Pinecone data self.pinecone_service = pinecone_service @@ -157,7 +156,6 @@ def __init__( self.message_queue = message_queue self.conversation_thread_owners = {} - async def load_file(self, file, ctx): try: async with aiofiles.open(file, "r") as f: @@ -212,7 +210,6 @@ async def on_ready(self): ) print(f"Commands synced") - # TODO: add extra condition to check if multi is enabled for the thread, stated in conversation_threads def check_conversing(self, user_id, channel_id, message_content, multi=None): cond1 = channel_id in self.conversation_threads @@ -298,7 +295,6 @@ async def end_conversation( traceback.print_exc() pass - async def send_settings_text(self, ctx): embed = discord.Embed( title="GPT3Bot Settings", @@ -386,10 +382,12 @@ async def paginate_and_send(self, response_text, ctx): async def paginate_embed(self, response_text, codex, prompt=None, instruction=None): - if codex: #clean codex input + if codex: # clean codex input response_text = response_text.replace("```", "") response_text = response_text.replace(f"***Prompt: {prompt}***\n", "") - response_text = response_text.replace(f"***Instruction: {instruction}***\n\n", "") + response_text = response_text.replace( + f"***Instruction: {instruction}***\n\n", "" + ) response_text = [ response_text[i : i + self.EMBED_CUTOFF] @@ -400,12 +398,20 @@ async def paginate_embed(self, response_text, codex, prompt=None, instruction=No # Send each chunk as a message for count, chunk in enumerate(response_text, start=1): if not first: - page = discord.Embed(title=f"Page {count}", description=chunk if not codex else f"***Prompt:{prompt}***\n***Instruction:{instruction:}***\n```python\n{chunk}\n```") + page = discord.Embed( + title=f"Page {count}", + description=chunk + if not codex + else f"***Prompt:{prompt}***\n***Instruction:{instruction:}***\n```python\n{chunk}\n```", + ) first = True else: - page = discord.Embed(title=f"Page {count}", description=chunk if not codex else f"```python\n{chunk}\n```") + page = discord.Embed( + title=f"Page {count}", + description=chunk if not codex else f"```python\n{chunk}\n```", + ) pages.append(page) - + return pages async def queue_debug_message(self, debug_message, debug_channel): @@ -511,11 +517,10 @@ async def on_message_edit(self, before, after): ).timestamp() await Moderation.moderation_queues[after.guild.id].put( Moderation(after, timestamp) - ) # TODO Don't proceed if message was deleted! + ) # TODO Don't proceed if message was deleted! await TextService.process_conversation_edit(self, after, original_message) - @discord.Cog.listener() async def on_message(self, message): if message.author == self.bot.user: @@ -534,11 +539,12 @@ async def on_message(self, message): ).timestamp() await Moderation.moderation_queues[message.guild.id].put( Moderation(message, timestamp) - ) # TODO Don't proceed to conversation processing if the message is deleted by moderations. - + ) # TODO Don't proceed to conversation processing if the message is deleted by moderations. # Process the message if the user is in a conversation - if await TextService.process_conversation_message(self, message, USER_INPUT_API_KEYS, USER_KEY_DB): + if await TextService.process_conversation_message( + self, message, USER_INPUT_API_KEYS, USER_KEY_DB + ): original_message[message.author.id] = message.id def cleanse_response(self, response_text): @@ -548,7 +554,9 @@ def cleanse_response(self, response_text): response_text = response_text.replace("<|endofstatement|>", "") return response_text - def remove_awaiting(self, author_id, channel_id, from_ask_command, from_edit_command): + def remove_awaiting( + self, author_id, channel_id, from_ask_command, from_edit_command + ): if author_id in self.awaiting_responses: self.awaiting_responses.remove(author_id) if not from_ask_command and not from_edit_command: @@ -569,7 +577,7 @@ async def mention_to_username(self, ctx, message): pass return message -# COMMANDS + # COMMANDS async def help_command(self, ctx): await ctx.defer() @@ -623,8 +631,9 @@ async def help_command(self, ctx): embed.add_field(name="/help", value="See this help text", inline=False) await ctx.respond(embed=embed) - - async def set_usage_command(self, ctx: discord.ApplicationContext, usage_amount: float): + async def set_usage_command( + self, ctx: discord.ApplicationContext, usage_amount: float + ): await ctx.defer() # Attempt to convert the input usage value into a float @@ -636,8 +645,9 @@ async def set_usage_command(self, ctx: discord.ApplicationContext, usage_amount: await ctx.respond("The usage value must be a valid float.") return - - async def delete_all_conversation_threads_command(self, ctx: discord.ApplicationContext): + async def delete_all_conversation_threads_command( + self, ctx: discord.ApplicationContext + ): await ctx.defer() for guild in self.bot.guilds: @@ -650,7 +660,6 @@ async def delete_all_conversation_threads_command(self, ctx: discord.Application pass await ctx.respond("All conversation threads have been deleted.") - async def usage_command(self, ctx): await ctx.defer() embed = discord.Embed( @@ -669,7 +678,6 @@ async def usage_command(self, ctx): ) await ctx.respond(embed=embed) - async def ask_command( self, ctx: discord.ApplicationContext, @@ -908,7 +916,6 @@ async def converse_command( if thread.id in self.awaiting_thread_responses: self.awaiting_thread_responses.remove(thread.id) - async def end_command(self, ctx: discord.ApplicationContext): await ctx.defer(ephemeral=True) user_id = ctx.user.id @@ -964,4 +971,3 @@ async def settings_command( # Otherwise, process the settings change await self.process_settings(ctx, parameter, value) - diff --git a/gpt3discord.py b/gpt3discord.py index 9d9e29ba..f8a1dc64 100644 --- a/gpt3discord.py +++ b/gpt3discord.py @@ -158,7 +158,6 @@ async def main(): ) ) - apply_multicog(bot) await bot.start(os.getenv("DISCORD_TOKEN")) diff --git a/services/image_service.py b/services/image_service.py index f2022508..524fd739 100644 --- a/services/image_service.py +++ b/services/image_service.py @@ -11,20 +11,19 @@ class ImageService: - def __init__(self): pass @staticmethod async def encapsulated_send( - image_service_cog, - user_id, - prompt, - ctx, - response_message=None, - vary=None, - draw_from_optimizer=None, - custom_api_key=None, + image_service_cog, + user_id, + prompt, + ctx, + response_message=None, + vary=None, + draw_from_optimizer=None, + custom_api_key=None, ): await asyncio.sleep(0) # send the prompt to the model @@ -93,7 +92,9 @@ async def encapsulated_send( ) image_service_cog.converser_cog.users_to_interactions[user_id] = [] - image_service_cog.converser_cog.users_to_interactions[user_id].append(result_message.id) + image_service_cog.converser_cog.users_to_interactions[user_id].append( + result_message.id + ) # Get the actual result message object if from_context: @@ -106,7 +107,7 @@ async def encapsulated_send( response=response_message, instruction=None, codex=False, - paginator=None + paginator=None, ) else: diff --git a/services/text_service.py b/services/text_service.py index ceb7d228..c6c64c57 100644 --- a/services/text_service.py +++ b/services/text_service.py @@ -53,13 +53,18 @@ async def encapsulated_send( try: # Pinecone is enabled, we will create embeddings for this conversation. - if converser_cog.pinecone_service and ctx.channel.id in converser_cog.conversation_threads: + if ( + converser_cog.pinecone_service + and ctx.channel.id in converser_cog.conversation_threads + ): # Delete "GPTie: <|endofstatement|>" from the user's conversation history if it exists # check if the text attribute for any object inside converser_cog.conversation_threads[converation_id].history # contains ""GPTie: <|endofstatement|>"", if so, delete for item in converser_cog.conversation_threads[ctx.channel.id].history: if item.text.strip() == "GPTie:<|endofstatement|>": - converser_cog.conversation_threads[ctx.channel.id].history.remove(item) + converser_cog.conversation_threads[ + ctx.channel.id + ].history.remove(item) # The conversation_id is the id of the thread conversation_id = ctx.channel.id @@ -133,12 +138,18 @@ async def encapsulated_send( for i in range( 1, min( - len(converser_cog.conversation_threads[ctx.channel.id].history), + len( + converser_cog.conversation_threads[ + ctx.channel.id + ].history + ), converser_cog.model.num_static_conversation_items, ), ): prompt_with_history.append( - converser_cog.conversation_threads[ctx.channel.id].history[-i] + converser_cog.conversation_threads[ctx.channel.id].history[ + -i + ] ) # remove duplicates from prompt_with_history and set the conversation history @@ -191,7 +202,9 @@ async def encapsulated_send( "".join( [ item.text - for item in converser_cog.conversation_threads[id].history + for item in converser_cog.conversation_threads[ + id + ].history ] ) + "\nGPTie: " @@ -238,12 +251,14 @@ async def encapsulated_send( ) # Clean the request response - response_text = converser_cog.cleanse_response(str(response["choices"][0]["text"])) + response_text = converser_cog.cleanse_response( + str(response["choices"][0]["text"]) + ) if from_ask_command: - # Append the prompt to the beginning of the response, in italics, then a new line - response_text = response_text.strip() - response_text = f"***{prompt}***\n\n{response_text}" + # Append the prompt to the beginning of the response, in italics, then a new line + response_text = response_text.strip() + response_text = f"***{prompt}***\n\n{response_text}" elif from_edit_command: if codex: response_text = response_text.strip() @@ -293,12 +308,14 @@ async def encapsulated_send( ) # Create and upsert the embedding for the conversation id, prompt, timestamp - embedding = await converser_cog.pinecone_service.upsert_conversation_embedding( - converser_cog.model, - conversation_id, - response_text, - timestamp, - custom_api_key=custom_api_key, + embedding = ( + await converser_cog.pinecone_service.upsert_conversation_embedding( + converser_cog.model, + conversation_id, + response_text, + timestamp, + custom_api_key=custom_api_key, + ) ) # Cleanse again @@ -307,7 +324,6 @@ async def encapsulated_send( # escape any other mentions like @here or @everyone response_text = discord.utils.escape_mentions(response_text) - # If we don't have a response message, we are not doing a redo, send as a new message(s) if not response_message: if len(response_text) > converser_cog.TEXT_CUTOFF: @@ -315,9 +331,21 @@ async def encapsulated_send( paginator = None await converser_cog.paginate_and_send(response_text, ctx) else: - embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction) - view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) - paginator = pages.Paginator(pages=embed_pages, timeout=None, custom_view=view) + embed_pages = await converser_cog.paginate_embed( + response_text, codex, prompt, instruction + ) + view = ConversationView( + ctx, + converser_cog, + ctx.channel.id, + model, + from_ask_command, + from_edit_command, + custom_api_key=custom_api_key, + ) + paginator = pages.Paginator( + pages=embed_pages, timeout=None, custom_view=view + ) response_message = await paginator.respond(ctx.interaction) else: paginator = None @@ -336,24 +364,24 @@ async def encapsulated_send( response_message = await ctx.respond( response_text, view=ConversationView( - ctx, - converser_cog, - ctx.channel.id, - model, - from_edit_command=from_edit_command, - custom_api_key=custom_api_key + ctx, + converser_cog, + ctx.channel.id, + model, + from_edit_command=from_edit_command, + custom_api_key=custom_api_key, ), ) else: response_message = await ctx.respond( response_text, view=ConversationView( - ctx, - converser_cog, - ctx.channel.id, - model, - from_ask_command=from_ask_command, - custom_api_key=custom_api_key + ctx, + converser_cog, + ctx.channel.id, + model, + from_ask_command=from_ask_command, + custom_api_key=custom_api_key, ), ) @@ -366,13 +394,13 @@ async def encapsulated_send( ) converser_cog.redo_users[ctx.author.id] = RedoUser( - prompt=new_prompt, - instruction=instruction, - ctx=ctx, - message=ctx, - response=actual_response_message, - codex=codex, - paginator=paginator + prompt=new_prompt, + instruction=instruction, + ctx=ctx, + message=ctx, + response=actual_response_message, + codex=codex, + paginator=paginator, ) converser_cog.redo_users[ctx.author.id].add_interaction( actual_response_message.id @@ -382,20 +410,35 @@ async def encapsulated_send( else: paginator = converser_cog.redo_users.get(ctx.author.id).paginator if isinstance(paginator, pages.Paginator): - embed_pages = await converser_cog.paginate_embed(response_text, codex, prompt, instruction) - view=ConversationView(ctx, converser_cog, ctx.channel.id, model, from_ask_command, from_edit_command, custom_api_key=custom_api_key) + embed_pages = await converser_cog.paginate_embed( + response_text, codex, prompt, instruction + ) + view = ConversationView( + ctx, + converser_cog, + ctx.channel.id, + model, + from_ask_command, + from_edit_command, + custom_api_key=custom_api_key, + ) await paginator.update(pages=embed_pages, custom_view=view) elif len(response_text) > converser_cog.TEXT_CUTOFF: if not from_context: - await response_message.channel.send("Over 2000 characters", delete_after=5) + await response_message.channel.send( + "Over 2000 characters", delete_after=5 + ) else: await response_message.edit(content=response_text) await converser_cog.send_debug_message( - converser_cog.generate_debug_message(prompt, response), converser_cog.debug_channel + converser_cog.generate_debug_message(prompt, response), + converser_cog.debug_channel, ) - converser_cog.remove_awaiting(ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command) + converser_cog.remove_awaiting( + ctx.author.id, ctx.channel.id, from_ask_command, from_edit_command + ) # Error catching for AIOHTTP Errors except aiohttp.ClientResponseError as e: @@ -439,7 +482,9 @@ async def encapsulated_send( return @staticmethod - async def process_conversation_message(converser_cog, message, USER_INPUT_API_KEYS, USER_KEY_DB): + async def process_conversation_message( + converser_cog, message, USER_INPUT_API_KEYS, USER_KEY_DB + ): content = message.content.strip() conversing = converser_cog.check_conversing( message.author.id, message.channel.id, content @@ -506,7 +551,9 @@ async def process_conversation_message(converser_cog, message, USER_INPUT_API_KE converser_cog.awaiting_thread_responses.append(message.channel.id) if not converser_cog.pinecone_service: - converser_cog.conversation_threads[message.channel.id].history.append( + converser_cog.conversation_threads[ + message.channel.id + ].history.append( EmbeddedConversationItem( f"\n'{message.author.display_name}': {prompt} <|endofstatement|>\n", 0, @@ -519,8 +566,8 @@ async def process_conversation_message(converser_cog, message, USER_INPUT_API_KE # Send the request to the model # If conversing, the prompt to send is the history, otherwise, it's just the prompt if ( - converser_cog.pinecone_service - or message.channel.id not in converser_cog.conversation_threads + converser_cog.pinecone_service + or message.channel.id not in converser_cog.conversation_threads ): primary_prompt = prompt else: @@ -528,13 +575,15 @@ async def process_conversation_message(converser_cog, message, USER_INPUT_API_KE [ item.text for item in converser_cog.conversation_threads[ - message.channel.id - ].history + message.channel.id + ].history ] ) # set conversation overrides - overrides = converser_cog.conversation_threads[message.channel.id].get_overrides() + overrides = converser_cog.conversation_threads[ + message.channel.id + ].get_overrides() await TextService.encapsulated_send( converser_cog, @@ -554,7 +603,7 @@ async def process_conversation_message(converser_cog, message, USER_INPUT_API_KE async def get_user_api_key(user_id, ctx, USER_KEY_DB): user_api_key = None if user_id not in USER_KEY_DB else USER_KEY_DB[user_id] if user_api_key is None or user_api_key == "": - modal = SetupModal(title="API Key Setup",user_key_db=USER_KEY_DB) + modal = SetupModal(title="API Key Setup", user_key_db=USER_KEY_DB) if isinstance(ctx, discord.ApplicationContext): await ctx.send_modal(modal) await ctx.send_followup( @@ -574,17 +623,25 @@ async def process_conversation_edit(converser_cog, after, original_message): ctx = converser_cog.redo_users[after.author.id].ctx await response_message.edit(content="Redoing prompt 🔄...") - edited_content = await converser_cog.mention_to_username(after, after.content) + edited_content = await converser_cog.mention_to_username( + after, after.content + ) if after.channel.id in converser_cog.conversation_threads: # Remove the last two elements from the history array and add the new : prompt converser_cog.conversation_threads[ after.channel.id - ].history = converser_cog.conversation_threads[after.channel.id].history[:-2] + ].history = converser_cog.conversation_threads[ + after.channel.id + ].history[ + :-2 + ] pinecone_dont_reinsert = None if not converser_cog.pinecone_service: - converser_cog.conversation_threads[after.channel.id].history.append( + converser_cog.conversation_threads[ + after.channel.id + ].history.append( EmbeddedConversationItem( f"\n{after.author.display_name}: {after.content}<|endofstatement|>\n", 0, @@ -593,7 +650,9 @@ async def process_conversation_edit(converser_cog, after, original_message): converser_cog.conversation_threads[after.channel.id].count += 1 - overrides = converser_cog.conversation_threads[after.channel.id].get_overrides() + overrides = converser_cog.conversation_threads[ + after.channel.id + ].get_overrides() await TextService.encapsulated_send( converser_cog, @@ -616,6 +675,8 @@ async def process_conversation_edit(converser_cog, after, original_message): """ Conversation interaction buttons """ + + class ConversationView(discord.ui.View): def __init__( self, @@ -663,7 +724,11 @@ async def on_timeout(self): class EndConvoButton(discord.ui.Button["ConversationView"]): def __init__(self, converser_cog): - super().__init__(style=discord.ButtonStyle.danger, label="End Conversation", custom_id="conversation_end") + super().__init__( + style=discord.ButtonStyle.danger, + label="End Conversation", + custom_id="conversation_end", + ) self.converser_cog = converser_cog async def callback(self, interaction: discord.Interaction): @@ -693,8 +758,14 @@ async def callback(self, interaction: discord.Interaction): class RedoButton(discord.ui.Button["ConversationView"]): - def __init__(self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key): - super().__init__(style=discord.ButtonStyle.danger, label="Retry", custom_id="conversation_redo") + def __init__( + self, converser_cog, model, from_ask_command, from_edit_command, custom_api_key + ): + super().__init__( + style=discord.ButtonStyle.danger, + label="Retry", + custom_id="conversation_redo", + ) self.converser_cog = converser_cog self.model = model self.from_ask_command = from_ask_command @@ -744,6 +815,8 @@ async def callback(self, interaction: discord.Interaction): """ The setup modal when using user input API keys """ + + class SetupModal(discord.ui.Modal): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs)