Skip to content

Commit

Permalink
chore(api): config sqlite history
Browse files Browse the repository at this point in the history
  • Loading branch information
nicobytes committed Oct 27, 2024
1 parent 6cc67ea commit f22b9ae
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 5 deletions.
4 changes: 2 additions & 2 deletions apps/api/src/chatbot/graph.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { END, START, StateGraph } from "@langchain/langgraph";
import { SqliteSaver } from "@langchain/langgraph-checkpoint-sqlite";
import { D1Saver } from "./saver";

import { GraphState } from "./graph.state";

Expand All @@ -21,7 +21,7 @@ interface Props {
}

export const createGraph = (data: Props, db: D1Database) => {
const memory = new SqliteSaver(db);
const memory = new D1Saver(db);

const llmGpt4 = models.gpt4(data.openAIKey);
const llmMistral = models.mistral(data.mistralKey);
Expand Down
4 changes: 1 addition & 3 deletions apps/api/src/chatbot/routers/intent.router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ Otherwise, respond only with the word "Conversation".`;

export const intentRouter = (llm: ChatMistralAI) => {
return async (state: GraphState) => {
const { lastAgent } = state;
// TODO:
const isReadyToBook = false;
const { lastAgent, isReadyToBook } = state;

if (lastAgent === MyNodes.AVAILABILITY || lastAgent === MyNodes.BOOKING) {
return lastAgent;
Expand Down
363 changes: 363 additions & 0 deletions apps/api/src/chatbot/saver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,363 @@
import type { RunnableConfig } from "@langchain/core/runnables";
import {
BaseCheckpointSaver,
type Checkpoint,
type CheckpointListOptions,
type CheckpointTuple,
type SerializerProtocol,
type PendingWrite,
type CheckpointMetadata,
} from "@langchain/langgraph-checkpoint";

type DatabaseType = D1Database;

interface CheckpointRow {
checkpoint: string;
metadata: string;
parent_checkpoint_id?: string;
thread_id: string;
checkpoint_id: string;
checkpoint_ns?: string;
type?: string;
}

interface WritesRow {
thread_id: string;
checkpoint_ns: string;
checkpoint_id: string;
task_id: string;
idx: number;
channel: string;
type?: string;
value?: string;
}

// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys
// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list
// of keys that we use is out of sync with the `CheckpointMetadata` type.
const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const;

type CheckKeys<T, K extends readonly (keyof T)[]> = [K[number]] extends [
keyof T,
]
? [keyof T] extends [K[number]]
? K
: never
: never;

function validateKeys<T, K extends readonly (keyof T)[]>(
keys: CheckKeys<T, K>,
): K {
return keys;
}

// If this line fails to compile, the list of keys that we use in the `SqliteSaver.list` method is out of sync with the
// `CheckpointMetadata` type. In that case, just update `checkpointMetadataKeys` to contain all the keys in
// `CheckpointMetadata`
const validCheckpointMetadataKeys = validateKeys<
CheckpointMetadata,
typeof checkpointMetadataKeys
>(checkpointMetadataKeys);

export class D1Saver extends BaseCheckpointSaver {
db: DatabaseType;

protected isSetup: boolean;

constructor(db: DatabaseType, serde?: SerializerProtocol) {
super(serde);
this.db = db;
this.isSetup = false;
}

protected setup(): void {
if (this.isSetup) {
return;
}

this.db.exec(`
CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`);
this.db.exec(`
CREATE TABLE IF NOT EXISTS writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
value BLOB,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`);

this.isSetup = true;
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
this.setup();
const {
thread_id,
checkpoint_ns = "",
checkpoint_id,
} = config.configurable ?? {};
let row: CheckpointRow | null;
if (checkpoint_id) {
row = await this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`,
)
.bind(thread_id, checkpoint_ns, checkpoint_id)
.first<CheckpointRow>();
} else {
row = await this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1`,
)
.bind(thread_id, checkpoint_ns)
.first<CheckpointRow>();
}
if (row === null) {
return undefined;
}
let finalConfig = config;
if (!checkpoint_id) {
finalConfig = {
configurable: {
thread_id: row.thread_id,
checkpoint_ns,
checkpoint_id: row.checkpoint_id,
},
};
}
if (
finalConfig.configurable?.thread_id === undefined ||
finalConfig.configurable?.checkpoint_id === undefined
) {
throw new Error("Missing thread_id or checkpoint_id");
}
// find any pending writes
const pendingWritesRows = await this.db
.prepare(
`SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`,
)
.bind(
finalConfig.configurable.thread_id.toString(),
checkpoint_ns,
finalConfig.configurable.checkpoint_id.toString(),
)
.all<WritesRow>();
const pendingWrites = await Promise.all(
pendingWritesRows.results.map(async (row) => {
return [
row.task_id,
row.channel,
await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""),
] as [string, string, unknown];
}),
);
return {
config: finalConfig,
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint,
)) as Checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata,
)) as CheckpointMetadata,
parentConfig: row.parent_checkpoint_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_ns,
checkpoint_id: row.parent_checkpoint_id,
},
}
: undefined,
pendingWrites,
};
}

async *list(
config: RunnableConfig,
options?: CheckpointListOptions,
): AsyncGenerator<CheckpointTuple> {
const { limit, before, filter } = options ?? {};
this.setup();
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;

let sql =
`SELECT\n` +
" thread_id,\n" +
" checkpoint_ns,\n" +
" checkpoint_id,\n" +
" parent_checkpoint_id,\n" +
" type,\n" +
" checkpoint,\n" +
" metadata\n" +
"FROM checkpoints\n";

const whereClause: string[] = [];

if (thread_id) {
whereClause.push("thread_id = ?");
}

if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
whereClause.push("checkpoint_ns = ?");
}

if (before?.configurable?.checkpoint_id !== undefined) {
whereClause.push("checkpoint_id < ?");
}

const sanitizedFilter = Object.fromEntries(
Object.entries(filter ?? {}).filter(
([key, value]) =>
value !== undefined &&
validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata),
),
);

whereClause.push(
...Object.entries(sanitizedFilter).map(
([key]) => `jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`,
),
);

if (whereClause.length > 0) {
sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`;
}

sql += "\nORDER BY checkpoint_id DESC";

if (limit) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
sql += ` LIMIT ${parseInt(limit as any, 10)}`; // parseInt here (with cast to make TS happy) to sanitize input, as limit may be user-provided
}

const args = [
thread_id,
checkpoint_ns,
before?.configurable?.checkpoint_id,
...Object.values(sanitizedFilter).map((value) => JSON.stringify(value)),
].filter((value) => value !== undefined && value !== null);

const rows = await this.db
.prepare(sql)
.bind(...args)
.all<CheckpointRow>();

if (rows) {
for (const row of rows.results) {
yield {
config: {
configurable: {
thread_id: row.thread_id,
checkpoint_ns: row.checkpoint_ns,
checkpoint_id: row.checkpoint_id,
},
},
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint,
)) as Checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata,
)) as CheckpointMetadata,
parentConfig: row.parent_checkpoint_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_ns: row.checkpoint_ns,
checkpoint_id: row.parent_checkpoint_id,
},
}
: undefined,
};
}
}
}

async put(
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
): Promise<RunnableConfig> {
this.setup();

const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint);
const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata);
if (type1 !== type2) {
throw new Error(
"Failed to serialized checkpoint and metadata to the same type.",
);
}
const row = [
config.configurable?.thread_id?.toString(),
config.configurable?.checkpoint_ns,
checkpoint.id,
config.configurable?.checkpoint_id,
type1,
serializedCheckpoint,
serializedMetadata,
];

await this.db
.prepare(
`INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`,
)
.bind(...row)
.all();

return {
configurable: {
thread_id: config.configurable?.thread_id,
checkpoint_ns: config.configurable?.checkpoint_ns,
checkpoint_id: checkpoint.id,
},
};
}

async putWrites(
config: RunnableConfig,
writes: PendingWrite[],
taskId: string,
): Promise<void> {
this.setup();

const rows = writes.map((write, idx) => {
const [type, serializedWrite] = this.serde.dumpsTyped(write[1]);
const args = [
config.configurable?.thread_id,
config.configurable?.checkpoint_ns,
config.configurable?.checkpoint_id,
taskId,
idx,
write[0],
type,
serializedWrite,
];
return this.db
.prepare(`
INSERT OR REPLACE INTO writes
(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`)
.bind(...args);
});

await this.db.batch(rows);
}
}

0 comments on commit f22b9ae

Please sign in to comment.