diff --git a/src/pages/Options.tsx b/src/pages/Options.tsx
index 983293c..a4e421d 100644
--- a/src/pages/Options.tsx
+++ b/src/pages/Options.tsx
@@ -45,7 +45,7 @@ export const EMBEDDING_MODELS = ["nomic-embed-text", "all-minilm"];
 export const CHAT_CONTAINER_HEIGHT_MIN = 200;
 export const CHAT_CONTAINER_HEIGHT_MAX = 500;
 
-interface LumosOptions {
+export interface LumosOptions {
   ollamaModel: string;
   ollamaEmbeddingModel: string;
   ollamaHost: string;
diff --git a/src/scripts/background.ts b/src/scripts/background.ts
index ccfb001..a6be561 100644
--- a/src/scripts/background.ts
+++ b/src/scripts/background.ts
@@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers";
 import {
   ChatPromptTemplate,
   SystemMessagePromptTemplate,
-  BaseMessagePromptTemplateLike,
 } from "@langchain/core/prompts";
 import { RunnableSequence } from "@langchain/core/runnables";
 import { ConsoleCallbackHandler } from "@langchain/core/tracers/console";
@@ -25,6 +24,7 @@ import {
   DEFAULT_KEEP_ALIVE,
   getLumosOptions,
   isMultimodal,
+  LumosOptions,
 } from "../pages/Options";
 
 interface VectorStoreMetadata {
@@ -45,6 +45,8 @@ const CLS_IMG_PROMPT =
   "Is the following prompt referring to an image or asking to describe an image?";
 const CLS_IMG_TRIGGER = "based on the image";
 
+const SYS_PROMPT_TEMPLATE = `Use the following context and the chat history when responding to the prompt.\n\nBEGIN CONTEXT\n\n{filtered_context}\n\nEND CONTEXT`;
+
 function sleep(ms: number) {
   return new Promise((resolve) => setTimeout(resolve, ms));
 }
@@ -92,13 +94,22 @@ const classifyPrompt = async (
   });
 };
 
+const getChatModel = (options: LumosOptions): ChatOllama => {
+  return new ChatOllama({
+    baseUrl: options.ollamaHost,
+    model: options.ollamaModel,
+    keepAlive: DEFAULT_KEEP_ALIVE,
+    callbacks: [new ConsoleCallbackHandler()],
+  });
+};
+
 const getMessages = async (): Promise<(HumanMessage | AIMessage)[]> => {
   // the array of persisted messages includes the current prompt
   const data = await chrome.storage.session.get(["messages"]);
 
   if (data.messages) {
     const msgs = data.messages as LumosMessage[];
-    return msgs.slice(-10).map((msg: LumosMessage) => {
+    return msgs.slice(-5).map((msg: LumosMessage) => {
       return msg.sender === "user"
         ? new HumanMessage({
             content: msg.message,
@@ -171,18 +182,9 @@ chrome.runtime.onMessage.addListener(async (request) => {
       return executeCalculatorTool(prompt);
     }
 
-    // create prompt
-    const chatPrompt = ChatPromptTemplate.fromMessages(await getMessages());
-
-    // create model
-    const model = new ChatOllama({
-      baseUrl: options.ollamaHost,
-      model: options.ollamaModel,
-      keepAlive: DEFAULT_KEEP_ALIVE,
-      callbacks: [new ConsoleCallbackHandler()],
-    });
-
     // create chain
+    const chatPrompt = ChatPromptTemplate.fromMessages(await getMessages());
+    const model = getChatModel(options);
     const chain = chatPrompt.pipe(model).pipe(new StringOutputParser());
 
     // stream response chunks
@@ -282,22 +284,6 @@ chrome.runtime.onMessage.addListener(async (request) => {
       return executeCalculatorTool(prompt);
     }
 
-    const template = `Use only the following context when responding to the prompt. Don't use any other knowledge.\n\nBEGIN CONTEXT\n\n{filtered_context}\n\nEND CONTEXT`;
-    let messages: BaseMessagePromptTemplateLike[] = [
-      SystemMessagePromptTemplate.fromTemplate(template),
-    ];
-    messages = messages.concat(await getMessages());
-    // TODO: append image data in new message
-    const chatPrompt = ChatPromptTemplate.fromMessages(messages);
-
-    // create model and bind base64 encoded image data
-    const model = new ChatOllama({
-      baseUrl: options.ollamaHost,
-      model: options.ollamaModel,
-      keepAlive: DEFAULT_KEEP_ALIVE,
-      callbacks: [new ConsoleCallbackHandler()],
-    });
-
     // check if vector store already exists for url
     let vectorStore: EnhancedMemoryVectorStore;
     let documentsCount: number;
@@ -359,13 +345,19 @@ chrome.runtime.onMessage.addListener(async (request) => {
       }
     }
 
+    // create chain
     const retriever = vectorStore.asRetriever({
      k: computeK(documentsCount),
       searchType: "hybrid",
       callbacks: [new ConsoleCallbackHandler()],
     });
 
-    // create chain
+    const chatPrompt = ChatPromptTemplate.fromMessages([
+      SystemMessagePromptTemplate.fromTemplate(SYS_PROMPT_TEMPLATE),
+      ...(await getMessages()),
+    ]);
+
+    const model = getChatModel(options);
     const chain = RunnableSequence.from([
       {
         filtered_context: retriever.pipe(formatDocumentsAsString),
@@ -376,7 +368,7 @@ chrome.runtime.onMessage.addListener(async (request) => {
     ]);
 
     // stream response chunks
-    const stream = await chain.stream("");
+    const stream = await chain.stream(prompt);
     streamChunks(stream);
   }