Implement chat for RAG workflow.
andrewnguonly committed Mar 13, 2024
1 parent eebb90d commit 4910e5e
Showing 2 changed files with 24 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/pages/Options.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export const EMBEDDING_MODELS = ["nomic-embed-text", "all-minilm"];
export const CHAT_CONTAINER_HEIGHT_MIN = 200;
export const CHAT_CONTAINER_HEIGHT_MAX = 500;

interface LumosOptions {
export interface LumosOptions {
ollamaModel: string;
ollamaEmbeddingModel: string;
ollamaHost: string;
Expand Down
54 changes: 23 additions & 31 deletions src/scripts/background.ts

@@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers";
 import {
   ChatPromptTemplate,
   SystemMessagePromptTemplate,
-  BaseMessagePromptTemplateLike,
 } from "@langchain/core/prompts";
 import { RunnableSequence } from "@langchain/core/runnables";
 import { ConsoleCallbackHandler } from "@langchain/core/tracers/console";
@@ -25,6 +24,7 @@ import {
   DEFAULT_KEEP_ALIVE,
   getLumosOptions,
   isMultimodal,
+  LumosOptions,
 } from "../pages/Options";
 
 interface VectorStoreMetadata {
@@ -45,6 +45,8 @@ const CLS_IMG_PROMPT =
   "Is the following prompt referring to an image or asking to describe an image?";
 const CLS_IMG_TRIGGER = "based on the image";
 
+const SYS_PROMPT_TEMPLATE = `Use the following context and the chat history when responding to the prompt.\n\nBEGIN CONTEXT\n\n{filtered_context}\n\nEND CONTEXT`;
+
 function sleep(ms: number) {
   return new Promise((resolve) => setTimeout(resolve, ms));
 }
@@ -92,13 +94,22 @@ const classifyPrompt = async (
   });
 };
 
+const getChatModel = (options: LumosOptions): ChatOllama => {
+  return new ChatOllama({
+    baseUrl: options.ollamaHost,
+    model: options.ollamaModel,
+    keepAlive: DEFAULT_KEEP_ALIVE,
+    callbacks: [new ConsoleCallbackHandler()],
+  });
+};
+
 const getMessages = async (): Promise<BaseMessage[]> => {
   // the array of persisted messages includes the current prompt
   const data = await chrome.storage.session.get(["messages"]);
 
   if (data.messages) {
     const msgs = data.messages as LumosMessage[];
-    return msgs.slice(-10).map((msg: LumosMessage) => {
+    return msgs.slice(-5).map((msg: LumosMessage) => {
       return msg.sender === "user"
         ? new HumanMessage({
             content: msg.message,
@@ -171,18 +182,9 @@ chrome.runtime.onMessage.addListener(async (request) => {
     return executeCalculatorTool(prompt);
   }
 
-  // create prompt
-  const chatPrompt = ChatPromptTemplate.fromMessages(await getMessages());
-
-  // create model
-  const model = new ChatOllama({
-    baseUrl: options.ollamaHost,
-    model: options.ollamaModel,
-    keepAlive: DEFAULT_KEEP_ALIVE,
-    callbacks: [new ConsoleCallbackHandler()],
-  });
-
   // create chain
+  const chatPrompt = ChatPromptTemplate.fromMessages(await getMessages());
+  const model = getChatModel(options);
   const chain = chatPrompt.pipe(model).pipe(new StringOutputParser());
 
   // stream response chunks
@@ -282,22 +284,6 @@ chrome.runtime.onMessage.addListener(async (request) => {
     return executeCalculatorTool(prompt);
   }
 
-  const template = `Use only the following context when responding to the prompt. Don't use any other knowledge.\n\nBEGIN CONTEXT\n\n{filtered_context}\n\nEND CONTEXT`;
-  let messages: BaseMessagePromptTemplateLike[] = [
-    SystemMessagePromptTemplate.fromTemplate(template),
-  ];
-  messages = messages.concat(await getMessages());
-  // TODO: append image data in new message
-  const chatPrompt = ChatPromptTemplate.fromMessages(messages);
-
-  // create model and bind base64 encoded image data
-  const model = new ChatOllama({
-    baseUrl: options.ollamaHost,
-    model: options.ollamaModel,
-    keepAlive: DEFAULT_KEEP_ALIVE,
-    callbacks: [new ConsoleCallbackHandler()],
-  });
-
   // check if vector store already exists for url
   let vectorStore: EnhancedMemoryVectorStore;
   let documentsCount: number;
@@ -359,13 +345,19 @@ chrome.runtime.onMessage.addListener(async (request) => {
     }
   }
 
-  // create chain
   const retriever = vectorStore.asRetriever({
     k: computeK(documentsCount),
     searchType: "hybrid",
     callbacks: [new ConsoleCallbackHandler()],
   });
 
+  // create chain
+  const chatPrompt = ChatPromptTemplate.fromMessages([
+    SystemMessagePromptTemplate.fromTemplate(SYS_PROMPT_TEMPLATE),
+    ...(await getMessages()),
+  ]);
+
+  const model = getChatModel(options);
   const chain = RunnableSequence.from([
     {
       filtered_context: retriever.pipe(formatDocumentsAsString),
@@ -376,7 +368,7 @@ chrome.runtime.onMessage.addListener(async (request) => {
   ]);
 
   // stream response chunks
-  const stream = await chain.stream("");
+  const stream = await chain.stream(prompt);
   streamChunks(stream);
 }
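
For readers skimming the diff, the net effect: model construction is centralized in `getChatModel()`, the RAG prompt now includes chat history via `getMessages()` (capped at the last 5 persisted messages), and `chain.stream(prompt)` passes the real prompt to the retriever instead of an empty string. Below is a minimal, self-contained sketch of the resulting chain, not the extension's actual code: the `answerWithRag` helper, the in-memory vector store and its sample text, the hard-coded history, and the Ollama host/model values are illustrative stand-ins, and the import paths assume LangChain JS as of this commit's era (early 2024).

```typescript
// Sketch of the RAG chat flow this commit wires up (assumptions noted above).
import { StringOutputParser } from "@langchain/core/output_parsers";
import {
  ChatPromptTemplate,
  SystemMessagePromptTemplate,
} from "@langchain/core/prompts";
import { RunnableSequence } from "@langchain/core/runnables";
import { AIMessage, HumanMessage } from "@langchain/core/messages";
import { ChatOllama } from "@langchain/community/chat_models/ollama";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { formatDocumentsAsString } from "langchain/util/document";

// same template shape added in this commit
const SYS_PROMPT_TEMPLATE = `Use the following context and the chat history when responding to the prompt.\n\nBEGIN CONTEXT\n\n{filtered_context}\n\nEND CONTEXT`;

const answerWithRag = async (prompt: string): Promise<string> => {
  // stand-in for the extension's EnhancedMemoryVectorStore built from page content
  const vectorStore = await MemoryVectorStore.fromTexts(
    ["Lumos is a local LLM co-pilot for browsing the web."],
    [{ source: "example" }],
    new OllamaEmbeddings({ model: "nomic-embed-text" }),
  );
  const retriever = vectorStore.asRetriever();

  // system prompt carries the retrieved context; trailing messages carry the
  // chat history (the extension pulls the last 5 persisted messages)
  const chatPrompt = ChatPromptTemplate.fromMessages([
    SystemMessagePromptTemplate.fromTemplate(SYS_PROMPT_TEMPLATE),
    new HumanMessage({ content: "What is Lumos?" }),
    new AIMessage({ content: "Lumos is a Chrome extension." }),
    new HumanMessage({ content: prompt }),
  ]);

  // counterpart of the getChatModel() helper introduced here
  const model = new ChatOllama({
    baseUrl: "http://localhost:11434",
    model: "llama2",
  });

  // the chain input (the raw prompt) feeds the retriever, whose documents are
  // formatted into {filtered_context} before the prompt reaches the model
  const chain = RunnableSequence.from([
    { filtered_context: retriever.pipe(formatDocumentsAsString) },
    chatPrompt,
    model,
    new StringOutputParser(),
  ]);

  return chain.invoke(prompt);
};
```

Passing the prompt into `chain.stream()` matters because the first step of the sequence pipes the chain input through the retriever; with the old `chain.stream("")`, retrieval ran against an empty query.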
