-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
366 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,363 @@ | ||
import type { RunnableConfig } from "@langchain/core/runnables"; | ||
import { | ||
BaseCheckpointSaver, | ||
type Checkpoint, | ||
type CheckpointListOptions, | ||
type CheckpointTuple, | ||
type SerializerProtocol, | ||
type PendingWrite, | ||
type CheckpointMetadata, | ||
} from "@langchain/langgraph-checkpoint"; | ||
|
||
type DatabaseType = D1Database; | ||
|
||
interface CheckpointRow { | ||
checkpoint: string; | ||
metadata: string; | ||
parent_checkpoint_id?: string; | ||
thread_id: string; | ||
checkpoint_id: string; | ||
checkpoint_ns?: string; | ||
type?: string; | ||
} | ||
|
||
interface WritesRow { | ||
thread_id: string; | ||
checkpoint_ns: string; | ||
checkpoint_id: string; | ||
task_id: string; | ||
idx: number; | ||
channel: string; | ||
type?: string; | ||
value?: string; | ||
} | ||
|
||
// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys | ||
// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list | ||
// of keys that we use is out of sync with the `CheckpointMetadata` type. | ||
const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const; | ||
|
||
type CheckKeys<T, K extends readonly (keyof T)[]> = [K[number]] extends [ | ||
keyof T, | ||
] | ||
? [keyof T] extends [K[number]] | ||
? K | ||
: never | ||
: never; | ||
|
||
function validateKeys<T, K extends readonly (keyof T)[]>( | ||
keys: CheckKeys<T, K>, | ||
): K { | ||
return keys; | ||
} | ||
|
||
// If this line fails to compile, the list of keys that we use in the `SqliteSaver.list` method is out of sync with the | ||
// `CheckpointMetadata` type. In that case, just update `checkpointMetadataKeys` to contain all the keys in | ||
// `CheckpointMetadata` | ||
const validCheckpointMetadataKeys = validateKeys< | ||
CheckpointMetadata, | ||
typeof checkpointMetadataKeys | ||
>(checkpointMetadataKeys); | ||
|
||
export class D1Saver extends BaseCheckpointSaver { | ||
db: DatabaseType; | ||
|
||
protected isSetup: boolean; | ||
|
||
constructor(db: DatabaseType, serde?: SerializerProtocol) { | ||
super(serde); | ||
this.db = db; | ||
this.isSetup = false; | ||
} | ||
|
||
protected setup(): void { | ||
if (this.isSetup) { | ||
return; | ||
} | ||
|
||
this.db.exec(` | ||
CREATE TABLE IF NOT EXISTS checkpoints ( | ||
thread_id TEXT NOT NULL, | ||
checkpoint_ns TEXT NOT NULL DEFAULT '', | ||
checkpoint_id TEXT NOT NULL, | ||
parent_checkpoint_id TEXT, | ||
type TEXT, | ||
checkpoint BLOB, | ||
metadata BLOB, | ||
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) | ||
);`); | ||
this.db.exec(` | ||
CREATE TABLE IF NOT EXISTS writes ( | ||
thread_id TEXT NOT NULL, | ||
checkpoint_ns TEXT NOT NULL DEFAULT '', | ||
checkpoint_id TEXT NOT NULL, | ||
task_id TEXT NOT NULL, | ||
idx INTEGER NOT NULL, | ||
channel TEXT NOT NULL, | ||
type TEXT, | ||
value BLOB, | ||
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) | ||
);`); | ||
|
||
this.isSetup = true; | ||
} | ||
|
||
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> { | ||
this.setup(); | ||
const { | ||
thread_id, | ||
checkpoint_ns = "", | ||
checkpoint_id, | ||
} = config.configurable ?? {}; | ||
let row: CheckpointRow | null; | ||
if (checkpoint_id) { | ||
row = await this.db | ||
.prepare( | ||
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`, | ||
) | ||
.bind(thread_id, checkpoint_ns, checkpoint_id) | ||
.first<CheckpointRow>(); | ||
} else { | ||
row = await this.db | ||
.prepare( | ||
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1`, | ||
) | ||
.bind(thread_id, checkpoint_ns) | ||
.first<CheckpointRow>(); | ||
} | ||
if (row === null) { | ||
return undefined; | ||
} | ||
let finalConfig = config; | ||
if (!checkpoint_id) { | ||
finalConfig = { | ||
configurable: { | ||
thread_id: row.thread_id, | ||
checkpoint_ns, | ||
checkpoint_id: row.checkpoint_id, | ||
}, | ||
}; | ||
} | ||
if ( | ||
finalConfig.configurable?.thread_id === undefined || | ||
finalConfig.configurable?.checkpoint_id === undefined | ||
) { | ||
throw new Error("Missing thread_id or checkpoint_id"); | ||
} | ||
// find any pending writes | ||
const pendingWritesRows = await this.db | ||
.prepare( | ||
`SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`, | ||
) | ||
.bind( | ||
finalConfig.configurable.thread_id.toString(), | ||
checkpoint_ns, | ||
finalConfig.configurable.checkpoint_id.toString(), | ||
) | ||
.all<WritesRow>(); | ||
const pendingWrites = await Promise.all( | ||
pendingWritesRows.results.map(async (row) => { | ||
return [ | ||
row.task_id, | ||
row.channel, | ||
await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""), | ||
] as [string, string, unknown]; | ||
}), | ||
); | ||
return { | ||
config: finalConfig, | ||
checkpoint: (await this.serde.loadsTyped( | ||
row.type ?? "json", | ||
row.checkpoint, | ||
)) as Checkpoint, | ||
metadata: (await this.serde.loadsTyped( | ||
row.type ?? "json", | ||
row.metadata, | ||
)) as CheckpointMetadata, | ||
parentConfig: row.parent_checkpoint_id | ||
? { | ||
configurable: { | ||
thread_id: row.thread_id, | ||
checkpoint_ns, | ||
checkpoint_id: row.parent_checkpoint_id, | ||
}, | ||
} | ||
: undefined, | ||
pendingWrites, | ||
}; | ||
} | ||
|
||
async *list( | ||
config: RunnableConfig, | ||
options?: CheckpointListOptions, | ||
): AsyncGenerator<CheckpointTuple> { | ||
const { limit, before, filter } = options ?? {}; | ||
this.setup(); | ||
const thread_id = config.configurable?.thread_id; | ||
const checkpoint_ns = config.configurable?.checkpoint_ns; | ||
|
||
let sql = | ||
`SELECT\n` + | ||
" thread_id,\n" + | ||
" checkpoint_ns,\n" + | ||
" checkpoint_id,\n" + | ||
" parent_checkpoint_id,\n" + | ||
" type,\n" + | ||
" checkpoint,\n" + | ||
" metadata\n" + | ||
"FROM checkpoints\n"; | ||
|
||
const whereClause: string[] = []; | ||
|
||
if (thread_id) { | ||
whereClause.push("thread_id = ?"); | ||
} | ||
|
||
if (checkpoint_ns !== undefined && checkpoint_ns !== null) { | ||
whereClause.push("checkpoint_ns = ?"); | ||
} | ||
|
||
if (before?.configurable?.checkpoint_id !== undefined) { | ||
whereClause.push("checkpoint_id < ?"); | ||
} | ||
|
||
const sanitizedFilter = Object.fromEntries( | ||
Object.entries(filter ?? {}).filter( | ||
([key, value]) => | ||
value !== undefined && | ||
validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata), | ||
), | ||
); | ||
|
||
whereClause.push( | ||
...Object.entries(sanitizedFilter).map( | ||
([key]) => `jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`, | ||
), | ||
); | ||
|
||
if (whereClause.length > 0) { | ||
sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`; | ||
} | ||
|
||
sql += "\nORDER BY checkpoint_id DESC"; | ||
|
||
if (limit) { | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
sql += ` LIMIT ${parseInt(limit as any, 10)}`; // parseInt here (with cast to make TS happy) to sanitize input, as limit may be user-provided | ||
} | ||
|
||
const args = [ | ||
thread_id, | ||
checkpoint_ns, | ||
before?.configurable?.checkpoint_id, | ||
...Object.values(sanitizedFilter).map((value) => JSON.stringify(value)), | ||
].filter((value) => value !== undefined && value !== null); | ||
|
||
const rows = await this.db | ||
.prepare(sql) | ||
.bind(...args) | ||
.all<CheckpointRow>(); | ||
|
||
if (rows) { | ||
for (const row of rows.results) { | ||
yield { | ||
config: { | ||
configurable: { | ||
thread_id: row.thread_id, | ||
checkpoint_ns: row.checkpoint_ns, | ||
checkpoint_id: row.checkpoint_id, | ||
}, | ||
}, | ||
checkpoint: (await this.serde.loadsTyped( | ||
row.type ?? "json", | ||
row.checkpoint, | ||
)) as Checkpoint, | ||
metadata: (await this.serde.loadsTyped( | ||
row.type ?? "json", | ||
row.metadata, | ||
)) as CheckpointMetadata, | ||
parentConfig: row.parent_checkpoint_id | ||
? { | ||
configurable: { | ||
thread_id: row.thread_id, | ||
checkpoint_ns: row.checkpoint_ns, | ||
checkpoint_id: row.parent_checkpoint_id, | ||
}, | ||
} | ||
: undefined, | ||
}; | ||
} | ||
} | ||
} | ||
|
||
async put( | ||
config: RunnableConfig, | ||
checkpoint: Checkpoint, | ||
metadata: CheckpointMetadata, | ||
): Promise<RunnableConfig> { | ||
this.setup(); | ||
|
||
const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); | ||
const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); | ||
if (type1 !== type2) { | ||
throw new Error( | ||
"Failed to serialized checkpoint and metadata to the same type.", | ||
); | ||
} | ||
const row = [ | ||
config.configurable?.thread_id?.toString(), | ||
config.configurable?.checkpoint_ns, | ||
checkpoint.id, | ||
config.configurable?.checkpoint_id, | ||
type1, | ||
serializedCheckpoint, | ||
serializedMetadata, | ||
]; | ||
|
||
await this.db | ||
.prepare( | ||
`INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)`, | ||
) | ||
.bind(...row) | ||
.all(); | ||
|
||
return { | ||
configurable: { | ||
thread_id: config.configurable?.thread_id, | ||
checkpoint_ns: config.configurable?.checkpoint_ns, | ||
checkpoint_id: checkpoint.id, | ||
}, | ||
}; | ||
} | ||
|
||
async putWrites( | ||
config: RunnableConfig, | ||
writes: PendingWrite[], | ||
taskId: string, | ||
): Promise<void> { | ||
this.setup(); | ||
|
||
const rows = writes.map((write, idx) => { | ||
const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); | ||
const args = [ | ||
config.configurable?.thread_id, | ||
config.configurable?.checkpoint_ns, | ||
config.configurable?.checkpoint_id, | ||
taskId, | ||
idx, | ||
write[0], | ||
type, | ||
serializedWrite, | ||
]; | ||
return this.db | ||
.prepare(` | ||
INSERT OR REPLACE INTO writes | ||
(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) | ||
VALUES (?, ?, ?, ?, ?, ?, ?, ?) | ||
`) | ||
.bind(...args); | ||
}); | ||
|
||
await this.db.batch(rows); | ||
} | ||
} |