From f22b9ae07f0b6d82fb144aedbb63a4cd225634e9 Mon Sep 17 00:00:00 2001 From: Nicolas Molina Date: Sun, 27 Oct 2024 01:53:14 -0400 Subject: [PATCH] chore(api): config sqlite history --- apps/api/src/chatbot/graph.ts | 4 +- apps/api/src/chatbot/routers/intent.router.ts | 4 +- apps/api/src/chatbot/saver.ts | 363 ++++++++++++++++++ 3 files changed, 366 insertions(+), 5 deletions(-) create mode 100644 apps/api/src/chatbot/saver.ts diff --git a/apps/api/src/chatbot/graph.ts b/apps/api/src/chatbot/graph.ts index ab7cdc1..8d7871a 100644 --- a/apps/api/src/chatbot/graph.ts +++ b/apps/api/src/chatbot/graph.ts @@ -1,5 +1,5 @@ import { END, START, StateGraph } from "@langchain/langgraph"; -import { SqliteSaver } from "@langchain/langgraph-checkpoint-sqlite"; +import { D1Saver } from "./saver"; import { GraphState } from "./graph.state"; @@ -21,7 +21,7 @@ interface Props { } export const createGraph = (data: Props, db: D1Database) => { - const memory = new SqliteSaver(db); + const memory = new D1Saver(db); const llmGpt4 = models.gpt4(data.openAIKey); const llmMistral = models.mistral(data.mistralKey); diff --git a/apps/api/src/chatbot/routers/intent.router.ts b/apps/api/src/chatbot/routers/intent.router.ts index 025b3ca..75c48c0 100644 --- a/apps/api/src/chatbot/routers/intent.router.ts +++ b/apps/api/src/chatbot/routers/intent.router.ts @@ -23,9 +23,7 @@ Otherwise, respond only with the word "Conversation".`; export const intentRouter = (llm: ChatMistralAI) => { return async (state: GraphState) => { - const { lastAgent } = state; - // TODO: - const isReadyToBook = false; + const { lastAgent, isReadyToBook } = state; if (lastAgent === MyNodes.AVAILABILITY || lastAgent === MyNodes.BOOKING) { return lastAgent; diff --git a/apps/api/src/chatbot/saver.ts b/apps/api/src/chatbot/saver.ts new file mode 100644 index 0000000..22be8b5 --- /dev/null +++ b/apps/api/src/chatbot/saver.ts @@ -0,0 +1,363 @@ +import type { RunnableConfig } from "@langchain/core/runnables"; +import { + 
BaseCheckpointSaver,
  type Checkpoint,
  type CheckpointListOptions,
  type CheckpointTuple,
  type SerializerProtocol,
  type PendingWrite,
  type CheckpointMetadata,
} from "@langchain/langgraph-checkpoint";

/** Database binding used by the saver (Cloudflare D1). */
type DatabaseType = D1Database;

/** Shape of a row read from the `checkpoints` table. */
interface CheckpointRow {
  checkpoint: string;
  metadata: string;
  parent_checkpoint_id?: string;
  thread_id: string;
  checkpoint_id: string;
  checkpoint_ns?: string;
  type?: string;
}

/** Shape of a row read from the `writes` table. */
interface WritesRow {
  thread_id: string;
  checkpoint_ns: string;
  checkpoint_id: string;
  task_id: string;
  idx: number;
  channel: string;
  type?: string;
  value?: string;
}

// In the `D1Saver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys
// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list
// of keys that we use is out of sync with the `CheckpointMetadata` type.
const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const;

/**
 * Resolves to `K` only when the tuple `K` lists exactly the keys of `T`
 * (no missing key, no extra key); resolves to `never` otherwise.
 */
type CheckKeys<T, K extends readonly (keyof T)[]> = [K[number]] extends [
  keyof T,
]
  ? [keyof T] extends [K[number]]
    ? K
    : never
  : never;

/**
 * Identity function whose only purpose is the compile-time check performed by
 * `CheckKeys`: the call fails to type-check when `keys` is out of sync with
 * the keys of `T`.
 *
 * @param keys - Tuple of key names expected to cover every key of `T`.
 * @returns The same `keys` value, unchanged.
 */
function validateKeys<T, K extends readonly (keyof T)[]>(
  keys: CheckKeys<T, K>,
): K {
  return keys;
}

// If this line fails to compile, the list of keys that we use in the `D1Saver.list` method is out of sync with the
// `CheckpointMetadata` type.
// In that case, just update `checkpointMetadataKeys` to contain all the keys in
// `CheckpointMetadata`
const validCheckpointMetadataKeys = validateKeys<
  CheckpointMetadata,
  typeof checkpointMetadataKeys
>(checkpointMetadataKeys);

/**
 * LangGraph checkpoint saver that persists checkpoints and pending writes in a
 * Cloudflare D1 database.
 *
 * Two tables are used, both created lazily on first use:
 * - `checkpoints`: one row per (thread_id, checkpoint_ns, checkpoint_id).
 * - `writes`: one row per (thread_id, checkpoint_ns, checkpoint_id, task_id, idx).
 */
export class D1Saver extends BaseCheckpointSaver {
  db: DatabaseType;

  protected isSetup: boolean;

  /**
   * @param db - D1 database binding used for all reads and writes.
   * @param serde - Optional serializer forwarded to `BaseCheckpointSaver`.
   */
  constructor(db: DatabaseType, serde?: SerializerProtocol) {
    super(serde);
    this.db = db;
    this.isSetup = false;
  }

  /**
   * Creates the `checkpoints` and `writes` tables if they do not exist yet.
   *
   * D1 calls return promises, so table creation is awaited and `isSetup` is
   * only flipped once it has succeeded. The DDL goes through prepared
   * statements rather than `exec()`, because the D1 documentation describes
   * `exec()` input as newline-separated statements and this DDL spans lines.
   */
  protected async setup(): Promise<void> {
    if (this.isSetup) {
      return;
    }

    await this.db.batch([
      this.db.prepare(`
CREATE TABLE IF NOT EXISTS checkpoints (
  thread_id TEXT NOT NULL,
  checkpoint_ns TEXT NOT NULL DEFAULT '',
  checkpoint_id TEXT NOT NULL,
  parent_checkpoint_id TEXT,
  type TEXT,
  checkpoint BLOB,
  metadata BLOB,
  PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`),
      this.db.prepare(`
CREATE TABLE IF NOT EXISTS writes (
  thread_id TEXT NOT NULL,
  checkpoint_ns TEXT NOT NULL DEFAULT '',
  checkpoint_id TEXT NOT NULL,
  task_id TEXT NOT NULL,
  idx INTEGER NOT NULL,
  channel TEXT NOT NULL,
  type TEXT,
  value BLOB,
  PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`),
    ]);

    this.isSetup = true;
  }

  /**
   * Normalizes a column value read from a BLOB column into something the
   * serializer accepts.
   *
   * NOTE(review): depending on the runtime, D1 appears to hand BLOBs back as
   * a string, an `ArrayBuffer`, or a plain array of byte values — confirm
   * against the deployed environment. Missing values become `""`.
   */
  private static toSerialized(raw: unknown): Uint8Array | string {
    if (typeof raw === "string" || raw instanceof Uint8Array) {
      return raw;
    }
    if (raw instanceof ArrayBuffer) {
      return new Uint8Array(raw);
    }
    if (Array.isArray(raw)) {
      return Uint8Array.from(raw as number[]);
    }
    return "";
  }

  /** Deserializes the `checkpoint` and `metadata` columns of one row. */
  private async decodeRow(
    row: CheckpointRow,
  ): Promise<{ checkpoint: Checkpoint; metadata: CheckpointMetadata }> {
    const type = row.type ?? "json";
    const checkpoint = (await this.serde.loadsTyped(
      type,
      D1Saver.toSerialized(row.checkpoint),
    )) as Checkpoint;
    const metadata = (await this.serde.loadsTyped(
      type,
      D1Saver.toSerialized(row.metadata),
    )) as CheckpointMetadata;
    return { checkpoint, metadata };
  }

  /**
   * Loads one checkpoint together with its pending writes.
   *
   * When `config.configurable.checkpoint_id` is set, that exact checkpoint is
   * loaded; otherwise the checkpoint with the greatest `checkpoint_id` for the
   * thread and namespace is loaded.
   *
   * @param config - Runnable config carrying `thread_id`, and optionally
   *   `checkpoint_ns` (defaults to `""`) and `checkpoint_id`.
   * @returns The checkpoint tuple, or `undefined` when no row matches or no
   *   `thread_id` was supplied.
   */
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    await this.setup();
    const {
      thread_id,
      checkpoint_ns = "",
      checkpoint_id,
    } = config.configurable ?? {};

    // Without a thread there is nothing to look up, and D1 rejects
    // `undefined` as a bound value.
    if (thread_id == null) {
      return undefined;
    }

    const columns =
      "thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata";
    const row = checkpoint_id
      ? await this.db
          .prepare(
            `SELECT ${columns} FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`,
          )
          .bind(thread_id, checkpoint_ns, checkpoint_id)
          .first<CheckpointRow>()
      : await this.db
          .prepare(
            `SELECT ${columns} FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1`,
          )
          .bind(thread_id, checkpoint_ns)
          .first<CheckpointRow>();

    if (row === null) {
      return undefined;
    }

    // Keep the caller's config when it already pins a checkpoint; otherwise
    // report the checkpoint that was actually resolved.
    const finalConfig: RunnableConfig = checkpoint_id
      ? config
      : {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        };

    // The row was matched on thread and checkpoint id, so its own key columns
    // identify the pending writes that belong to it.
    const pendingWritesRows = await this.db
      .prepare(
        `SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx`,
      )
      .bind(row.thread_id, checkpoint_ns, row.checkpoint_id)
      .all<WritesRow>();

    const pendingWrites = await Promise.all(
      pendingWritesRows.results.map(
        async (write): Promise<[string, string, unknown]> => [
          write.task_id,
          write.channel,
          await this.serde.loadsTyped(
            write.type ?? "json",
            D1Saver.toSerialized(write.value),
          ),
        ],
      ),
    );

    const { checkpoint, metadata } = await this.decodeRow(row);

    return {
      config: finalConfig,
      checkpoint,
      metadata,
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns,
              checkpoint_id: row.parent_checkpoint_id,
            },
          }
        : undefined,
      pendingWrites,
    };
  }

  /**
   * Yields stored checkpoints, newest `checkpoint_id` first.
   *
   * @param config - Optional `thread_id` / `checkpoint_ns` restrict the
   *   result to one thread / namespace.
   * @param options - `limit` caps the number of rows, `before` keeps only
   *   checkpoints with a smaller `checkpoint_id`, `filter` matches metadata
   *   fields (only keys of `CheckpointMetadata` are honoured).
   */
  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions,
  ): AsyncGenerator<CheckpointTuple> {
    const { limit, before, filter } = options ?? {};
    await this.setup();
    const thread_id = config.configurable?.thread_id;
    const checkpoint_ns = config.configurable?.checkpoint_ns;
    const beforeId = before?.configurable?.checkpoint_id;

    // Each condition pushes its placeholder and its bound value together so
    // the two lists can never get out of step.
    const whereClause: string[] = [];
    const args: unknown[] = [];

    if (thread_id) {
      whereClause.push("thread_id = ?");
      args.push(thread_id);
    }

    if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
      whereClause.push("checkpoint_ns = ?");
      args.push(checkpoint_ns);
    }

    if (beforeId !== undefined && beforeId !== null) {
      whereClause.push("checkpoint_id < ?");
      args.push(beforeId);
    }

    for (const [key, value] of Object.entries(filter ?? {})) {
      // Only whitelisted metadata keys are interpolated into the SQL text.
      if (
        value === undefined ||
        !validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata)
      ) {
        continue;
      }
      const encoded = JSON.stringify(value);
      if (encoded === undefined) {
        // Not representable as JSON (e.g. a function): nothing to compare with.
        continue;
      }
      whereClause.push(`jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`);
      args.push(encoded);
    }

    let sql =
      `SELECT\n` +
      "  thread_id,\n" +
      "  checkpoint_ns,\n" +
      "  checkpoint_id,\n" +
      "  parent_checkpoint_id,\n" +
      "  type,\n" +
      "  checkpoint,\n" +
      "  metadata\n" +
      "FROM checkpoints\n";

    if (whereClause.length > 0) {
      sql += `WHERE\n  ${whereClause.join(" AND\n  ")}\n`;
    }

    sql += "\nORDER BY checkpoint_id DESC";

    if (limit) {
      // `limit` may be user-provided: only a parsed, positive integer is ever
      // placed in the SQL text; anything else means "no limit".
      const parsedLimit = Number.parseInt(String(limit), 10);
      if (Number.isFinite(parsedLimit) && parsedLimit > 0) {
        sql += ` LIMIT ${parsedLimit}`;
      }
    }

    const rows = await this.db
      .prepare(sql)
      .bind(...args)
      .all<CheckpointRow>();

    for (const row of rows.results) {
      const { checkpoint, metadata } = await this.decodeRow(row);
      yield {
        config: {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns: row.checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        },
        checkpoint,
        metadata,
        parentConfig: row.parent_checkpoint_id
          ? {
              configurable: {
                thread_id: row.thread_id,
                checkpoint_ns: row.checkpoint_ns,
                checkpoint_id: row.parent_checkpoint_id,
              },
            }
          : undefined,
      };
    }
  }

  /**
   * Stores (or replaces) a checkpoint.
   *
   * `config.configurable.checkpoint_id`, when present, is recorded as the
   * parent of the new checkpoint.
   *
   * @param config - Must carry `thread_id`; `checkpoint_ns` defaults to `""`.
   * @param checkpoint - Checkpoint to persist; its `id` becomes the row key.
   * @param metadata - Metadata stored alongside the checkpoint.
   * @returns Config pointing at the checkpoint that was just written.
   * @throws Error when `thread_id` is missing, or when checkpoint and
   *   metadata serialize to different types.
   */
  async put(
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata,
  ): Promise<RunnableConfig> {
    await this.setup();

    const threadId = config.configurable?.thread_id;
    if (threadId == null) {
      throw new Error(
        `Missing "thread_id" field in passed "config.configurable".`,
      );
    }
    // The column is NOT NULL and lookups default the namespace to "", so an
    // absent namespace is stored as "" as well.
    const checkpointNs = config.configurable?.checkpoint_ns ?? "";
    // D1 rejects `undefined` bind values; a root checkpoint has a NULL parent.
    const parentCheckpointId = config.configurable?.checkpoint_id ?? null;

    const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint);
    const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata);
    if (type1 !== type2) {
      throw new Error(
        "Failed to serialized checkpoint and metadata to the same type.",
      );
    }

    await this.db
      .prepare(
        `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`,
      )
      .bind(
        threadId.toString(),
        checkpointNs,
        checkpoint.id,
        parentCheckpointId,
        type1,
        serializedCheckpoint,
        serializedMetadata,
      )
      .run();

    return {
      configurable: {
        thread_id: threadId,
        checkpoint_ns: checkpointNs,
        checkpoint_id: checkpoint.id,
      },
    };
  }

  /**
   * Stores (or replaces) the pending writes produced by one task for a
   * checkpoint. All rows are sent in a single D1 batch.
   *
   * @param config - Must carry `thread_id` and `checkpoint_id`;
   *   `checkpoint_ns` defaults to `""`.
   * @param writes - `[channel, value]` pairs; the array index becomes `idx`.
   * @param taskId - Identifier of the task that produced the writes.
   * @throws Error when `thread_id` or `checkpoint_id` is missing.
   */
  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string,
  ): Promise<void> {
    await this.setup();

    const threadId = config.configurable?.thread_id;
    const checkpointId = config.configurable?.checkpoint_id;
    if (threadId == null || checkpointId == null) {
      throw new Error("Missing thread_id or checkpoint_id");
    }
    const checkpointNs = config.configurable?.checkpoint_ns ?? "";

    if (writes.length === 0) {
      // Nothing to persist; avoid sending an empty batch.
      return;
    }

    const insert = this.db.prepare(`
      INSERT OR REPLACE INTO writes
      (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `);

    const statements = writes.map(([channel, value], idx) => {
      const [type, serializedWrite] = this.serde.dumpsTyped(value);
      return insert.bind(
        threadId,
        checkpointNs,
        checkpointId,
        taskId,
        idx,
        channel,
        type,
        serializedWrite,
      );
    });

    await this.db.batch(statements);
  }
}