This repository has been archived by the owner on Jun 26, 2024. It is now read-only.

Commit

add discord chatbot (#15)
* add discord chatbot

* update
aniketmaurya authored Jul 16, 2023
1 parent 628d7fc commit 1b790fd
Showing 4 changed files with 91 additions and 7 deletions.
8 changes: 6 additions & 2 deletions examples/chatbot/chatbot.ipynb
@@ -140,7 +140,9 @@
"metadata": {},
"outputs": [],
"source": [
"output = conversation(\"PyTorch Lightning is an open-source library developed by Lightning AI team.\")[\"response\"]\n",
"output = conversation(\n",
" \"PyTorch Lightning is an open-source library developed by Lightning AI team.\"\n",
")[\"response\"]\n",
"print(output)"
]
},
@@ -150,7 +152,9 @@
"metadata": {},
"outputs": [],
"source": [
"output = conversation(\"who developed PyTorch Lightning? just give me the name of the team or person and nothing else.\")[\"response\"]\n",
"output = conversation(\n",
" \"who developed PyTorch Lightning? just give me the name of the team or person and nothing else.\"\n",
")[\"response\"]\n",
"print(output)"
]
},
53 changes: 53 additions & 0 deletions examples/chatbot/discord_bot.py
@@ -0,0 +1,53 @@
# pip install discord.py
# Learn more here - https://github.com/aniketmaurya/docs-QnA-discord-bot/tree/main
import os

import discord
from dotenv import load_dotenv

from llm_chain import LitGPTConversationChain, LitGPTLLM
from llm_chain.templates import longchat_prompt_template
from llm_inference import prepare_weights

load_dotenv()

# path = prepare_weights("lmsys/longchat-7b-16k")
path = "checkpoints/lmsys/longchat-13b-16k"
llm = LitGPTLLM(checkpoint_dir=path, quantize="bnb.nf4")
llm("warm up!")
TOKEN = os.environ.get("DISCORD_BOT_TOKEN")


class MyClient(discord.Client):
    BOT_INSTANCE = {}

    def chat(self, user_id, query):
        if user_id in self.BOT_INSTANCE:
            return self.BOT_INSTANCE[user_id].send(query)

        self.BOT_INSTANCE[user_id] = LitGPTConversationChain.from_llm(
            llm=llm, prompt=longchat_prompt_template
        )
        return self.BOT_INSTANCE[user_id].send(query)

    bot = LitGPTConversationChain.from_llm(llm=llm, prompt=longchat_prompt_template)

    async def on_ready(self):
        print(f"Logged on as {self.user}!")

    async def on_message(self, message):
        if message.author.id == self.user.id:
            return
        print(f"Message from {message.author}: {message.content}")

        if message.content.startswith("!help"):
            query = message.content.replace("!help", "")
            result = self.bot.send(query)
            await message.reply(result, mention_author=True)


intents = discord.Intents.default()
intents.message_content = True

client = MyClient(intents=intents)
client.run(TOKEN)
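
A quick way to sanity-check the chain setup the bot relies on, without connecting to Discord, is a local call like the sketch below. This is not part of the commit; the checkpoint path and the nf4 quantization flag are taken from the script above and assume the weights are already prepared. Note that on_message replies through the shared class-level bot chain, while the chat() helper, which caches one chain per user id, is defined but never called here.

# Local smoke test (sketch, not part of this commit): exercise the same LitGPT
# chain the Discord bot uses, without starting a Discord client.
from llm_chain import LitGPTConversationChain, LitGPTLLM
from llm_chain.templates import longchat_prompt_template

llm = LitGPTLLM(
    checkpoint_dir="checkpoints/lmsys/longchat-13b-16k",  # same path as the bot script
    quantize="bnb.nf4",  # assumes bitsandbytes and a CUDA GPU are available
)
chain = LitGPTConversationChain.from_llm(llm=llm, prompt=longchat_prompt_template)
print(chain.send("Who developed PyTorch Lightning?"))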
2 changes: 1 addition & 1 deletion examples/chatbot/gradio_demo.py
@@ -16,7 +16,7 @@

def respond(message, chat_history):
    bot_message = bot.send(message)
    chat_history.append((f"👤 {message}", f"🤖 {bot_message}"))
    chat_history.append((f"👤 {message}", f"{bot_message}"))
    return "", chat_history

msg.submit(respond, [msg, chatbot], [msg, chatbot])
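
The rest of gradio_demo.py is not shown in this hunk. For context, a minimal Gradio Blocks layout that the visible respond callback and msg.submit wiring would fit into might look like the sketch below; the component names, checkpoint path, and LitGPT chain setup are assumptions based on the other files in this commit, not the actual file header.

# Sketch of a minimal Blocks layout around respond() -- illustrative only; the
# real gradio_demo.py setup is outside this diff hunk.
import gradio as gr

from llm_chain import LitGPTConversationChain, LitGPTLLM
from llm_chain.templates import longchat_prompt_template

llm = LitGPTLLM(checkpoint_dir="checkpoints/lmsys/longchat-13b-16k")  # assumed path
bot = LitGPTConversationChain.from_llm(llm=llm, prompt=longchat_prompt_template)

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()

    def respond(message, chat_history):
        bot_message = bot.send(message)
        chat_history.append((f"👤 {message}", f"{bot_message}"))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()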
35 changes: 31 additions & 4 deletions src/llm_inference/model.py
@@ -236,14 +236,41 @@ def chat(
        max_new_tokens: int = 100,
        top_k: int = 200,
        temperature: float = 0.1,
        eos_id=None,
    ) -> str:
        output = self.__call__(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
        tokenizer = self.tokenizer
        model = self.model
        fabric = self.fabric

        encoded = tokenizer.encode(prompt, device=fabric.device)
        prompt_length = encoded.size(0)
        max_returned_tokens = model.config.block_size

        t0 = time.perf_counter()
        y = _generate(
            model,
            encoded,
            max_returned_tokens,
            max_seq_length=max_returned_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )
        t = time.perf_counter() - t0

        model.reset_cache()
        output = tokenizer.decode(y[prompt_length:])
        tokens_generated = y.size(0) - prompt_length
        fabric.print(
            f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
            file=sys.stderr,
        )
        if fabric.device.type == "cuda":
            fabric.print(
                f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB",
                file=sys.stderr,
            )

        return output

    def eval(self):
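
For orientation, a hedged usage sketch of the reworked chat() method follows. The class that owns it (and its constructor) sits outside this hunk, so llm_server below is a hypothetical placeholder for an instance of that class; only keyword arguments visible in the signature are used.

# Hypothetical usage -- `llm_server` stands in for an instance of the (unshown)
# class in src/llm_inference/model.py that defines chat().
reply = llm_server.chat(
    prompt="Who developed PyTorch Lightning?",
    top_k=200,
    temperature=0.1,
)
# The new body encodes the prompt, generates up to model.config.block_size tokens
# (stopping at the tokenizer's eos token), decodes only y[prompt_length:] so the
# prompt is not echoed back, resets the KV cache, and prints timing (and CUDA
# memory) stats to stderr before returning the completion string.
print(reply)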
