diff --git a/examples/chatbot/chatbot.ipynb b/examples/chatbot/chatbot.ipynb
index f4eb268..8d08390 100644
--- a/examples/chatbot/chatbot.ipynb
+++ b/examples/chatbot/chatbot.ipynb
@@ -140,7 +140,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "output = conversation(\"PyTorch Lightning is an open-source library developed by Lightning AI team.\")[\"response\"]\n",
+    "output = conversation(\n",
+    "    \"PyTorch Lightning is an open-source library developed by Lightning AI team.\"\n",
+    ")[\"response\"]\n",
     "print(output)"
    ]
   },
@@ -150,7 +152,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "output = conversation(\"who developed PyTorch Lightning? just give me the name of the team or person and nothing else.\")[\"response\"]\n",
+    "output = conversation(\n",
+    "    \"who developed PyTorch Lightning? just give me the name of the team or person and nothing else.\"\n",
+    ")[\"response\"]\n",
     "print(output)"
    ]
   },
diff --git a/examples/chatbot/discord_bot.py b/examples/chatbot/discord_bot.py
new file mode 100644
index 0000000..8c9e4f7
--- /dev/null
+++ b/examples/chatbot/discord_bot.py
@@ -0,0 +1,53 @@
+# pip install discord.py
+# Learn more here - https://github.com/aniketmaurya/docs-QnA-discord-bot/tree/main
+import os
+
+import discord
+from dotenv import load_dotenv
+
+from llm_chain import LitGPTConversationChain, LitGPTLLM
+from llm_chain.templates import longchat_prompt_template
+from llm_inference import prepare_weights
+
+load_dotenv()
+
+# path = prepare_weights("lmsys/longchat-7b-16k")
+path = "checkpoints/lmsys/longchat-13b-16k"
+llm = LitGPTLLM(checkpoint_dir=path, quantize="bnb.nf4")
+llm("warm up!")
+TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
+
+
+class MyClient(discord.Client):
+    BOT_INSTANCE = {}
+
+    def chat(self, user_id, query):
+        if user_id in self.BOT_INSTANCE:
+            return self.BOT_INSTANCE[user_id].send(query)
+
+        self.BOT_INSTANCE[user_id] = LitGPTConversationChain.from_llm(
+            llm=llm, prompt=longchat_prompt_template
+        )
+        return self.BOT_INSTANCE[user_id].send(query)
+
+    bot = LitGPTConversationChain.from_llm(llm=llm, prompt=longchat_prompt_template)
+
+    async def on_ready(self):
+        print(f"Logged on as {self.user}!")
+
+    async def on_message(self, message):
+        if message.author.id == self.user.id:
+            return
+        print(f"Message from {message.author}: {message.content}")
+
+        if message.content.startswith("!help"):
+            query = message.content.replace("!help", "")
+            result = self.bot.send(query)
+            await message.reply(result, mention_author=True)
+
+
+intents = discord.Intents.default()
+intents.message_content = True
+
+client = MyClient(intents=intents)
+client.run(TOKEN)
diff --git a/examples/chatbot/gradio_demo.py b/examples/chatbot/gradio_demo.py
index e22737a..930d6e8 100644
--- a/examples/chatbot/gradio_demo.py
+++ b/examples/chatbot/gradio_demo.py
@@ -16,7 +16,7 @@
 
     def respond(message, chat_history):
         bot_message = bot.send(message)
-        chat_history.append((f"👤 {message}", f"🤖 {bot_message}"))
+        chat_history.append((f"👤 {message}", f"{bot_message}"))
         return "", chat_history
 
     msg.submit(respond, [msg, chatbot], [msg, chatbot])
diff --git a/src/llm_inference/model.py b/src/llm_inference/model.py
index c066c9f..8e709dd 100644
--- a/src/llm_inference/model.py
+++ b/src/llm_inference/model.py
@@ -236,14 +236,41 @@ def chat(
         max_new_tokens: int = 100,
         top_k: int = 200,
         temperature: float = 0.1,
+        eos_id=None,
     ) -> str:
-        output = self.__call__(
-            prompt=prompt,
-            max_new_tokens=max_new_tokens,
-            top_k=top_k,
+        tokenizer = self.tokenizer
+        model = self.model
+        fabric = self.fabric
+
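+        # Tokenize the prompt on the Fabric device; total generation (prompt plus
+        # new tokens) is capped at the model's block size below.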
+        encoded = tokenizer.encode(prompt, device=fabric.device)
+        prompt_length = encoded.size(0)
+        max_returned_tokens = model.config.block_size
+
+        t0 = time.perf_counter()
+        y = _generate(
+            model,
+            encoded,
+            max_returned_tokens,
+            max_seq_length=max_returned_tokens,
             temperature=temperature,
+            top_k=top_k,
             eos_id=self.tokenizer.eos_id,
         )
+        t = time.perf_counter() - t0
+
+        model.reset_cache()
+        output = tokenizer.decode(y[prompt_length:])
+        tokens_generated = y.size(0) - prompt_length
+        fabric.print(
+            f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
+            file=sys.stderr,
+        )
+        if fabric.device.type == "cuda":
+            fabric.print(
+                f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB",
+                file=sys.stderr,
+            )
+        return output
 
     def eval(self):