diff --git a/Cargo.lock b/Cargo.lock index b51e07de2..6192b80b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,6 +123,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -224,6 +239,21 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "argminmax" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + [[package]] name = "arrayref" version = "0.3.7" @@ -375,7 +405,7 @@ dependencies = [ "arrow-schema 51.0.0", "arrow-select 51.0.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -396,7 +426,7 @@ dependencies = [ "arrow-schema 52.2.0", "arrow-select 52.2.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -709,6 +739,15 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + [[package]] name = "async-executor" version = "1.5.1" @@ -767,6 +806,32 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-openai" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e5ff98f9e7c605df4c88783a0439d1dc667ce86bd79e99d4164f8b0c05ccc" +dependencies = [ + "async-convert", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder 0.20.1", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest 0.12.5", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util 0.7.11", + "tracing", +] + [[package]] name = "async-priority-channel" version = "0.2.0" @@ -832,6 +897,28 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "async-task" version = "4.4.0" @@ -875,6 +962,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "atomic-waker" version = "1.1.1" @@ -1384,6 +1477,20 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.10", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -1435,9 +1542,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64-simd" @@ -1623,6 +1730,27 @@ dependencies = [ "syn_derive", ] +[[package]] +name = "brotli" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19483b140a7ac7174d34b5a581b406c64f84da5409d3e09cf4fff604f9270e67" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bs58" version = "0.5.0" @@ -1633,6 +1761,17 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata 0.4.7", + "serde", +] + [[package]] name = "buf_redux" version = "0.8.4" @@ -1700,6 +1839,20 @@ name = "bytemuck" version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] [[package]] name = "byteorder" @@ -1915,6 +2068,17 @@ dependencies = [ "parse-zoneinfo", ] +[[package]] +name = "chrono-tz" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +dependencies = [ + "chrono", + "chrono-tz-build 0.2.1", + "phf 0.11.2", +] + [[package]] name = "chrono-tz" version = "0.9.0" @@ -1922,8 +2086,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" dependencies = [ "chrono", - "chrono-tz-build", + "chrono-tz-build 0.3.0", + "phf 0.11.2", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", "phf 0.11.2", + "phf_codegen 0.11.2", ] [[package]] @@ -2153,6 +2328,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ + "crossterm", "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", @@ -2165,7 +2341,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0436149c9f6a1935b13306206c739b1ba84fa81f551b5eb87fc2ca7a13700af" dependencies = [ "clap 4.5.4", - "derive_builder", + "derive_builder 0.12.0", "entities", "memchr", "once_cell", @@ -2451,6 +2627,28 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.0", + "crossterm_winapi", + "libc", + "parking_lot 0.12.1", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -2819,7 +3017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f" dependencies = [ "arrow 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -2906,7 +3104,7 @@ dependencies = [ "arrow-ord 52.2.0", "arrow-schema 52.2.0", "arrow-string 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -3082,7 +3280,16 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" dependencies = [ - "derive_builder_macro", + "derive_builder_macro 0.12.0", +] + +[[package]] +name = "derive_builder" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" +dependencies = [ + "derive_builder_macro 0.20.1", ] [[package]] @@ -3097,16 +3304,38 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder_core" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "derive_builder_macro" version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" dependencies = [ - "derive_builder_core", + "derive_builder_core 0.12.0", "syn 1.0.109", ] +[[package]] +name = "derive_builder_macro" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" +dependencies = [ + "derive_builder_core 0.20.1", + "syn 2.0.66", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -3278,6 +3507,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "ecdsa" version = "0.14.8" @@ -3337,9 +3572,9 @@ checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" [[package]] name = "either" -version = "1.9.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "elliptic-curve" @@ -3422,6 +3657,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "env_logger" version = "0.9.3" @@ -3778,6 +4025,12 @@ dependencies = [ "yansi", ] +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + [[package]] name = "event-listener" version = "2.5.3" @@ -3795,6 +4048,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -3821,6 +4085,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -3831,6 +4101,22 @@ dependencies = [ "regex", ] +[[package]] +name = "fancy-regex" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" +dependencies = [ + "bit-set", + "regex", +] + +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + [[package]] name = "fastdivide" version = "0.4.1" @@ -3993,6 +4279,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -4441,6 +4733,8 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.11", "allocator-api2", + "rayon", + "serde", ] [[package]] @@ -4521,9 +4815,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -4854,7 +5148,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows 0.48.0", ] [[package]] @@ -5055,7 +5349,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", "windows-sys 0.48.0", ] @@ -5088,7 +5382,7 @@ version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "rustix 0.38.32", "windows-sys 0.48.0", ] @@ -5141,6 +5435,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "itoap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" + [[package]] name = "jetscii" version = "0.5.3" @@ -6028,11 +6328,21 @@ version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9106e1d747ffd48e6be5bb2d97fa706ed25b144fbee4d5c02eae110cd8d6badd" +[[package]] +name = "lz4" +version = "1.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958b4caa893816eea05507c20cfe47574a43d9a697138a7872990bba8a0ece68" +dependencies = [ + "libc", + "lz4-sys", +] + [[package]] name = "lz4-sys" -version = "1.9.4" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" dependencies = [ "cc", "libc", @@ -6181,6 +6491,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + [[package]] name = "memmap2" version = "0.9.4" @@ -6297,13 +6616,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -6410,19 +6730,41 @@ dependencies = [ ] [[package]] -name = "murmurhash32" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" - -[[package]] -name = "mustache" -version = "0.9.0" +name = "multiversion" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51956ef1c5d20a1384524d91e616fb44dfc7d8f249bf696d49c97dd3289ecab5" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" dependencies = [ - "log 0.3.9", - "serde", + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "target-features", +] + +[[package]] +name = "murmurhash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" + +[[package]] +name = "mustache" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51956ef1c5d20a1384524d91e616fb44dfc7d8f249bf696d49c97dd3289ecab5" +dependencies = [ + "log 0.3.9", + "serde", ] [[package]] @@ -6524,6 +6866,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -6642,7 +7002,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", ] @@ -6683,7 +7043,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" dependencies = [ "async-trait", - "base64 0.22.0", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -7088,6 +7448,16 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parquet-format-safe" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" +dependencies = [ + "async-trait", + "futures", +] + [[package]] name = "parse-zoneinfo" version = "0.3.0" @@ -7554,6 +7924,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + [[package]] name = "platforms" version = "3.2.0" @@ -7608,6 +7987,416 @@ dependencies = [ "miniz_oxide 0.7.1", ] +[[package]] +name = "polars" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3351ea4570e54cd556e6755b78fe7a2c85368d820c0307cca73c96e796a7ba" +dependencies = [ + "getrandom 0.2.10", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-sql", + "polars-time", + "polars-utils", + "version_check 0.9.4", +] + +[[package]] +name = "polars-arrow" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba65fc4bcabbd64fca01fd30e759f8b2043f0963c57619e331d4b534576c0b47" +dependencies = [ + "ahash 0.8.11", + "atoi", + "atoi_simd", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "futures", + "getrandom 0.2.10", + "hashbrown 0.14.5", + "itoa 1.0.9", + "itoap", + "lz4", + "multiversion", + "num-traits", + "polars-arrow-format", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check 0.9.4", + "zstd 0.13.2", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f099516af30ac9ae4b4480f4ad02aa017d624f2f37b7a16ad4e9ba52f7e5269" +dependencies = [ + "bytemuck", + "either", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "strength_reduce", + "version_check 0.9.4", +] + +[[package]] +name = "polars-core" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2439484be228b8c302328e2f953e64cfd93930636e5c7ceed90339ece7fef6c" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "comfy-table", + "either", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand 0.8.5", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check 0.9.4", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9b06dfbe79cabe50a7f0a90396864b5ee2c0e0f8d6a9353b2343c29c56e937" +dependencies = [ + "polars-arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-expr" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c630385a56a867c410a20f30772d088f90ec3d004864562b84250b35268f97" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", +] + +[[package]] +name = "polars-io" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" +dependencies = [ + "ahash 0.8.11", + "async-trait", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "futures", + "home", + "itoa 1.0.9", + "memchr", + "memmap2 0.7.1", + "num-traits", + "once_cell", + "percent-encoding 2.3.1", + "polars-arrow", + "polars-core", + "polars-error", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", + "tokio", + "tokio-util 0.7.11", +] + +[[package]] +name = "polars-lazy" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03877e74e42b5340ae52ded705f6d5d14563d90554c9177b01b91ed2412a56ed" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "glob", + "memchr", + "once_cell", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check 0.9.4", +] + +[[package]] +name = "polars-mem-engine" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea9e17771af750c94bf959885e4b3f5b14149576c62ef3ec1c9ef5827b2a30f" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6066552eb577d43b307027fb38096910b643ffb2c89a21628c7e41caf57848d0" +dependencies = [ + "ahash 0.8.11", + "argminmax", + "base64 0.22.1", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "hex", + "indexmap 2.1.0", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "unicode-reverse", + "version_check 0.9.4", +] + +[[package]] +name = "polars-parquet" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" +dependencies = [ + "ahash 0.8.11", + "async-stream", + "base64 0.22.1", + "brotli", + "ethnum", + "flate2", + "futures", + "lz4", + "num-traits", + "parquet-format-safe", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + "simdutf8", + "snap", + "streaming-decompression", + "zstd 0.13.2", +] + +[[package]] +name = "polars-pipe" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021bce7768c330687d735340395a77453aa18dd70d57c184cbb302311e87c1b9" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown 0.14.5", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "uuid 1.8.0", + "version_check 0.9.4", +] + +[[package]] +name = "polars-plan" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "220d0d7c02d1c4375802b2813dbedcd1a184df39c43b74689e729ede8d5c2921" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "once_cell", + "percent-encoding 2.3.1", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "recursive", + "regex", + "smartstring", + "strum_macros 0.26.4", + "version_check 0.9.4", +] + +[[package]] +name = "polars-row" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1d70d87a2882a64a43b431aea1329cb9a2c4100547c95c417cc426bb82408b3" +dependencies = [ + "bytemuck", + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fc1c9b778862f09f4a347f768dfdd3d0ba9957499d306d83c7103e0fa8dc5b" +dependencies = [ + "hex", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "rand 0.8.5", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "179f98313a15c0bfdbc8cc0f1d3076d08d567485b9952d46439f94fbc3085df5" +dependencies = [ + "atoi", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53e6dd89fcccb1ec1a62f752c9a9f2d482a85e9255153f46efecc617b4996d50" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-error", + "raw-cpuid 11.0.1", + "rayon", + "smartstring", + "stacker", + "sysinfo", + "version_check 0.9.4", +] + [[package]] name = "polling" version = "2.8.0" @@ -7923,6 +8712,15 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -8449,6 +9247,26 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.66", +] + [[package]] name = "redox_syscall" version = "0.3.5" @@ -8596,7 +9414,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -8612,6 +9430,7 @@ dependencies = [ "js-sys", "log 0.4.21", "mime 0.3.17", + "mime_guess 2.0.4", "once_cell", "percent-encoding 2.3.1", "pin-project-lite", @@ -8637,6 +9456,22 @@ dependencies = [ "winreg 0.52.0", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime 0.3.17", + "nom", + "pin-project-lite", + "reqwest 0.12.5", + "thiserror", +] + [[package]] name = "rfc6979" version = "0.3.1" @@ -9014,7 +9849,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -9229,6 +10064,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -9370,7 +10215,7 @@ version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -9495,6 +10340,24 @@ dependencies = [ "dirs", ] +[[package]] +name = "shinkai-graphrag" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "async-trait", + "futures", + "polars", + "polars-lazy", + "rand 0.8.5", + "reqwest 0.11.27", + "serde", + "serde_json", + "tiktoken-rs", + "tokio", +] + [[package]] name = "shinkai_crypto_identities" version = "0.1.1" @@ -9537,7 +10400,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9717,7 +10580,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9973,6 +10836,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg 1.1.0", + "static_assertions", + "version_check 0.9.4", +] + [[package]] name = "snafu" version = "0.7.5" @@ -9995,6 +10869,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.4.9" @@ -10100,6 +10980,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "winapi", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -10118,6 +11011,27 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "string_cache" version = "0.8.7" @@ -10280,7 +11194,7 @@ checksum = "874dcfa363995604333cf947ae9f751ca3af4522c60886774c4963943b4746b1" dependencies = [ "bincode", "bitflags 1.3.2", - "fancy-regex", + "fancy-regex 0.11.0", "flate2", "fnv", "once_cell", @@ -10295,6 +11209,20 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "sysinfo" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows 0.52.0", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -10343,7 +11271,7 @@ checksum = "f8d0582f186c0a6d55655d24543f15e43607299425c5ad8352c242b914b31856" dependencies = [ "aho-corasick", "arc-swap", - "base64 0.22.0", + "base64 0.22.1", "bitpacking", "byteorder", "census", @@ -10360,7 +11288,7 @@ dependencies = [ "lru 0.12.3", "lz4_flex", "measure_time", - "memmap2", + "memmap2 0.9.4", "num_cpus", "once_cell", "oneshot", @@ -10493,6 +11421,12 @@ dependencies = [ "xattr", ] +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + [[package]] name = "target-lexicon" version = "0.12.14" @@ -10622,6 +11556,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" +dependencies = [ + "anyhow", + "base64 0.21.7", + "bstr", + "fancy-regex 0.12.0", + "lazy_static", + "parking_lot 0.12.1", + "rustc-hash", +] + [[package]] name = "time" version = "0.1.45" @@ -10722,22 +11671,21 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.6", "tokio-macros", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -10752,9 +11700,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", @@ -10794,9 +11742,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" dependencies = [ "futures-core", "pin-project-lite", @@ -11250,6 +12198,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-reverse" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" @@ -11758,6 +12715,25 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.3", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.3", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -11993,6 +12969,12 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "xxhash-rust" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index b4fdec92c..7e4bb91c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,14 @@ members = [ "shinkai-libs/shinkai-dsl", "shinkai-libs/shinkai-sheet", "shinkai-libs/shinkai-fs-mirror", + "shinkai-libs/shinkai-graphrag", "shinkai-libs/shinkai-message-primitives", "shinkai-libs/shinkai-ocr", "shinkai-libs/shinkai-tcp-relayer", "shinkai-libs/shinkai-vector-resources", "shinkai-bin/*", "shinkai-cli-tools/*" -] +, "shinkai-libs/shinkai-graphrag"] resolver = "2" [workspace.package] diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore new file mode 100644 index 000000000..74deb7343 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/.gitignore @@ -0,0 +1,2 @@ +.vscode +dataset \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml new file mode 100644 index 000000000..7975fa77c --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "shinkai-graphrag" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.86" +async-trait = "0.1.74" +futures = "0.3.30" +polars = { version = "0.41.3", features = ["dtype-struct", "lazy", "parquet"] } +polars-lazy = "0.41.3" +rand = "0.8.5" +serde = { version = "1.0.188", features = ["derive"] } +serde_json = "1.0.117" +tokio = { version = "1.36", features = ["full"] } + +[dev-dependencies] +async-openai = "0.23.4" +reqwest = { version = "0.11.26", features = ["json"] } +tiktoken-rs = "0.5.9" \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs new file mode 100644 index 000000000..a6a1a7485 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -0,0 +1,473 @@ +use std::{ + collections::{HashMap, HashSet}, + io::{Cursor, Read}, +}; + +use polars::{ + frame::DataFrame, + io::SerWriter, + prelude::{col, concat, CsvWriter, DataType, IntoLazy, LazyFrame, NamedFrom, SortMultipleOptions, UnionArgs}, + series::Series, +}; +use rand::prelude::SliceRandom; + +use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport}; + +pub struct GlobalCommunityContext { + community_reports: Vec, + entities: Option>, + num_tokens_fn: fn(&str) -> usize, +} + +impl GlobalCommunityContext { + pub fn new( + community_reports: Vec, + entities: Option>, + num_tokens_fn: fn(&str) -> usize, + ) -> Self { + Self { + community_reports, + entities, + num_tokens_fn, + } + } + + pub async fn build_context( + &self, + context_builder_params: ContextBuilderParams, + ) -> anyhow::Result<(Vec, HashMap)> { + let ContextBuilderParams { + use_community_summary, + column_delimiter, + shuffle_data, + include_community_rank, + min_community_rank, + community_rank_name, + include_community_weight, + community_weight_name, + normalize_community_weight, + max_tokens, + context_name, + } = context_builder_params; + + let (community_context, community_context_data) = CommunityContext::build_community_context( + self.community_reports.clone(), + self.entities.clone(), + self.num_tokens_fn, + use_community_summary, + &column_delimiter, + shuffle_data, + include_community_rank, + min_community_rank, + &community_rank_name, + include_community_weight, + &community_weight_name, + normalize_community_weight, + max_tokens, + false, + &context_name, + )?; + + let final_context = community_context; + let final_context_data = community_context_data; + + Ok((final_context, final_context_data)) + } +} + +pub struct CommunityContext {} + +impl CommunityContext { + pub fn build_community_context( + community_reports: Vec, + entities: Option>, + num_tokens_fn: fn(&str) -> usize, + use_community_summary: bool, + column_delimiter: &str, + shuffle_data: bool, + include_community_rank: bool, + min_community_rank: u32, + community_rank_name: &str, + include_community_weight: bool, + community_weight_name: &str, + normalize_community_weight: bool, + max_tokens: usize, + single_batch: bool, + context_name: &str, + ) -> anyhow::Result<(Vec, HashMap)> { + let _is_included = |report: &CommunityReport| -> bool { + report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into() + }; + + let _get_header = |attributes: Vec| -> Vec { + let mut header = vec!["id".to_string(), "title".to_string()]; + let mut filtered_attributes: Vec = attributes + .iter() + .filter(|&col| !header.contains(&col.to_string())) + .cloned() + .collect(); + + if !include_community_weight { + filtered_attributes.retain(|col| col != community_weight_name); + } + + header.extend(filtered_attributes.into_iter().map(|s| s.to_string())); + header.push(if use_community_summary { + "summary".to_string() + } else { + "content".to_string() + }); + + if include_community_rank { + header.push(community_rank_name.to_string()); + } + + header + }; + + let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec) { + let mut context: Vec = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()]; + + for field in attributes { + let value = report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + context.push(value); + } + + context.push(if use_community_summary { + report.summary.clone() + } else { + report.full_content.clone() + }); + + if include_community_rank { + context.push(report.rank.unwrap_or_default().to_string()); + } + + let result = context.join(column_delimiter) + "\n"; + (result, context) + }; + + let compute_community_weights = entities.as_ref().is_some_and(|e| !e.is_empty()) + && !community_reports.is_empty() + && include_community_weight + && (community_reports[0].attributes.is_none() + || !community_reports[0] + .attributes + .as_ref() + .unwrap() + .contains_key(community_weight_name)); + + let mut community_reports = community_reports; + if compute_community_weights { + community_reports = Self::_compute_community_weights( + community_reports, + entities.clone(), + community_weight_name, + normalize_community_weight, + ); + } + + let mut selected_reports: Vec = community_reports + .iter() + .filter(|&report| _is_included(report)) + .cloned() + .collect(); + + if selected_reports.is_empty() { + return Ok((Vec::new(), HashMap::new())); + } + + if shuffle_data { + let mut rng = rand::thread_rng(); + selected_reports.shuffle(&mut rng); + } + + let attributes = if let Some(attributes) = &community_reports[0].attributes { + attributes.keys().cloned().collect::>() + } else { + Vec::new() + }; + + let header = _get_header(attributes.clone()); + let mut all_context_text: Vec = Vec::new(); + let mut all_context_records: Vec = Vec::new(); + + let mut batch = Batch::new(); + + batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn); + + for report in selected_reports { + let (new_context_text, new_context) = _report_context_text(&report, &attributes); + let new_tokens = num_tokens_fn(&new_context_text); + + // add the current batch to the context data and start a new batch if we are in multi-batch mode + if batch.batch_tokens + new_tokens > max_tokens { + batch.cut_batch( + &mut all_context_text, + &mut all_context_records, + entities.clone(), + &header, + community_weight_name, + community_rank_name, + include_community_weight, + include_community_rank, + column_delimiter, + )?; + + if single_batch { + break; + } + + batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn); + } + + batch.batch_text.push_str(&new_context_text); + batch.batch_tokens += new_tokens; + batch.batch_records.push(new_context); + } + + if !all_context_text.contains(&batch.batch_text) { + batch.cut_batch( + &mut all_context_text, + &mut all_context_records, + entities.clone(), + &header, + community_weight_name, + community_rank_name, + include_community_weight, + include_community_rank, + column_delimiter, + )?; + } + + if all_context_records.is_empty() { + eprintln!("Warning: No community records added when building community context."); + return Ok((Vec::new(), HashMap::new())); + } + + let records_concat = concat( + all_context_records + .into_iter() + .map(|df| df.lazy()) + .collect::>(), + UnionArgs::default(), + )? + .collect()?; + + Ok(( + all_context_text, + HashMap::from([(context_name.to_lowercase(), records_concat)]), + )) + } + + fn _compute_community_weights( + community_reports: Vec, + entities: Option>, + weight_attribute: &str, + normalize: bool, + ) -> Vec { + // Calculate a community's weight as the count of text units associated with entities within the community. + if let Some(entities) = entities { + let mut community_reports = community_reports; + let mut community_text_units = std::collections::HashMap::new(); + for entity in entities { + if let Some(community_ids) = entity.community_ids.clone() { + for community_id in community_ids { + community_text_units + .entry(community_id) + .or_insert_with(Vec::new) + .extend(entity.text_unit_ids.clone()); + } + } + } + for report in &mut community_reports { + if report.attributes.is_none() { + report.attributes = Some(std::collections::HashMap::new()); + } + if let Some(attributes) = &mut report.attributes { + attributes.insert( + weight_attribute.to_string(), + community_text_units + .get(&report.community_id) + .map(|text_units| text_units.iter().flatten().cloned().collect::>().len()) + .unwrap_or(0) + .to_string(), + ); + } + } + if normalize { + // Normalize by max weight + let all_weights: Vec = community_reports + .iter() + .filter_map(|report| { + report + .attributes + .as_ref() + .and_then(|attributes| attributes.get(weight_attribute)) + .map(|weight| weight.parse::().unwrap_or(0.0)) + }) + .collect(); + if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { + for report in &mut community_reports { + if let Some(attributes) = &mut report.attributes { + if let Some(weight) = attributes.get_mut(weight_attribute) { + *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); + } + } + } + } + } + + return community_reports; + } + community_reports + } +} + +struct Batch { + batch_text: String, + batch_tokens: usize, + batch_records: Vec>, +} + +impl Batch { + fn new() -> Self { + Batch { + batch_text: String::new(), + batch_tokens: 0, + batch_records: Vec::new(), + } + } + + fn init_batch( + &mut self, + context_name: &str, + header: &Vec, + column_delimiter: &str, + num_tokens_fn: fn(&str) -> usize, + ) { + self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); + self.batch_tokens = num_tokens_fn(&self.batch_text); + self.batch_records.clear(); + } + + fn cut_batch( + &mut self, + all_context_text: &mut Vec, + all_context_records: &mut Vec, + entities: Option>, + header: &Vec, + community_weight_name: &str, + community_rank_name: &str, + include_community_weight: bool, + include_community_rank: bool, + column_delimiter: &str, + ) -> anyhow::Result<()> { + let weight_column = if include_community_weight && entities.as_ref().is_some_and(|e| !e.is_empty()) { + Some(community_weight_name) + } else { + None + }; + let rank_column = if include_community_rank { + Some(community_rank_name) + } else { + None + }; + + let mut record_df = Self::_convert_report_context_to_df( + self.batch_records.clone(), + header.clone(), + weight_column, + rank_column, + )?; + if record_df.is_empty() { + return Ok(()); + } + + let column_delimiter = if column_delimiter.is_empty() { + b'|' + } else { + column_delimiter.as_bytes()[0] + }; + + let mut buffer = Cursor::new(Vec::new()); + CsvWriter::new(&mut buffer) + .include_header(true) + .with_separator(column_delimiter) + .finish(&mut record_df)?; + + let mut current_context_text = String::new(); + buffer.set_position(0); + buffer.read_to_string(&mut current_context_text)?; + + all_context_text.push(current_context_text); + all_context_records.push(record_df); + + Ok(()) + } + + fn _convert_report_context_to_df( + context_records: Vec>, + header: Vec, + weight_column: Option<&str>, + rank_column: Option<&str>, + ) -> anyhow::Result { + if context_records.is_empty() { + return Ok(DataFrame::empty()); + } + + let mut data_series = Vec::new(); + for (index, header) in header.iter().enumerate() { + let records = context_records + .iter() + .map(|r| r.get(index).unwrap_or(&String::new()).to_owned()) + .collect::>(); + let series = Series::new(header, records); + data_series.push(series); + } + + let record_df = DataFrame::new(data_series)?; + + return Self::_rank_report_context(record_df, weight_column, rank_column); + } + + fn _rank_report_context( + report_df: DataFrame, + weight_column: Option<&str>, + rank_column: Option<&str>, + ) -> anyhow::Result { + let mut rank_attributes = Vec::new(); + + let mut report_df = report_df; + + if let Some(weight_column) = weight_column { + rank_attributes.push(weight_column); + report_df = report_df + .lazy() + .with_column(col(weight_column).cast(DataType::Float64)) + .collect()?; + } + + if let Some(rank_column) = rank_column { + rank_attributes.push(rank_column); + report_df = report_df + .lazy() + .with_column(col(rank_column).cast(DataType::Float64)) + .collect()?; + } + + if !rank_attributes.is_empty() { + report_df = report_df + .lazy() + .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) + .collect()?; + } + + Ok(report_df) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs new file mode 100644 index 000000000..87db20231 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs @@ -0,0 +1,19 @@ +#[derive(Debug, Clone)] +pub struct ContextBuilderParams { + //conversation_history: Option, + pub use_community_summary: bool, + pub column_delimiter: String, + pub shuffle_data: bool, + pub include_community_rank: bool, + pub min_community_rank: u32, + pub community_rank_name: String, + pub include_community_weight: bool, + pub community_weight_name: String, + pub normalize_community_weight: bool, + pub max_tokens: usize, + pub context_name: String, + // conversation_history_user_turns_only: bool, + // conversation_history_max_turns: Option, +} + +pub struct ConversationHistory {} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs new file mode 100644 index 000000000..26d8566fd --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -0,0 +1,262 @@ +use std::collections::{HashMap, HashSet}; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +use super::indexer_reports::filter_under_community_level; + +pub fn read_indexer_entities( + final_nodes: &DataFrame, + final_entities: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_embedding_df = final_entities.clone(); + + let entity_df = entity_df + .lazy() + .rename(["title", "degree"], ["name", "rank"]) + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .with_column(col("rank").cast(DataType::Int32)) + .group_by([col("name"), col("rank")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .join( + entity_embedding_df.lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let entities = read_entities( + entity_df, + "id", + Some("human_readable_id"), + "name", + Some("type"), + Some("description"), + None, + Some("description_embedding"), + None, + Some("community"), + Some("text_unit_ids"), + None, + Some("rank"), + )?; + + Ok(entities) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Entity { + pub id: String, + pub short_id: Option, + pub title: String, + pub entity_type: Option, + pub description: Option, + pub description_embedding: Option>, + pub name_embedding: Option>, + pub graph_embedding: Option>, + pub community_ids: Option>, + pub text_unit_ids: Option>, + pub document_ids: Option>, + pub rank: Option, + pub attributes: Option>, +} + +pub fn read_entities( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + type_col: Option<&str>, + description_col: Option<&str>, + name_embedding_col: Option<&str>, + description_embedding_col: Option<&str>, + graph_embedding_col: Option<&str>, + community_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + rank_col: Option<&str>, + // attributes_cols: Option>, +) -> anyhow::Result> { + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + type_col, + description_col, + name_embedding_col, + description_embedding_col, + graph_embedding_col, + community_col, + text_unit_ids_col, + document_ids_col, + rank_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + let column_names = column_names.into_iter().collect::>().into_vec(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut entities = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = Entity { + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + entity_type: type_col + .map(|type_col| get_field(&row, type_col, &column_names)) + .flatten() + .map(|entity_type| entity_type.to_string()), + description: description_col + .map(|description_col| get_field(&row, description_col, &column_names)) + .flatten() + .map(|description| description.to_string()), + name_embedding: name_embedding_col.map(|name_embedding_col| { + get_field(&row, name_embedding_col, &column_names) + .map(|name_embedding| match name_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(name_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + description_embedding: description_embedding_col.map(|description_embedding_col| { + get_field(&row, description_embedding_col, &column_names) + .map(|description_embedding| match description_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + graph_embedding: graph_embedding_col.map(|graph_embedding_col| { + get_field(&row, graph_embedding_col, &column_names) + .map(|graph_embedding| match graph_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(graph_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + community_ids: community_col.map(|community_col| { + get_field(&row, community_col, &column_names) + .map(|community_ids| match community_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { + get_field(&row, text_unit_ids_col, &column_names) + .map(|text_unit_ids| match text_unit_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(&row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + rank: rank_col + .map(|rank_col| { + get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) + }) + .flatten(), + attributes: None, + }; + entities.push(report); + } + + let mut unique_entities: Vec = Vec::new(); + let mut entity_ids: HashSet = HashSet::new(); + + for entity in entities { + if !entity_ids.contains(&entity.id) { + unique_entities.push(entity.clone()); + entity_ids.insert(entity.id); + } + } + + Ok(unique_entities) +} + +pub fn get_field<'a>( + row: &'a Vec>, + column_name: &'a str, + column_names: &'a Vec, +) -> Option> { + match column_names.iter().position(|x| x == column_name) { + Some(index) => row.get(index).cloned(), + None => None, + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs new file mode 100644 index 000000000..1b59b9c59 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -0,0 +1,172 @@ +use std::collections::{HashMap, HashSet}; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +use super::indexer_entities::get_field; + +pub fn read_indexer_reports( + final_community_reports: &DataFrame, + final_nodes: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let filtered_community_df = entity_df + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .group_by([col("title")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .filter(len().over([col("community")]).gt(lit(1))) + .collect()?; + + let report_df = final_community_reports.clone(); + let report_df = filter_under_community_level(&report_df, community_level)?; + + let report_df = report_df + .lazy() + .join( + filtered_community_df.lazy(), + [col("community")], + [col("community")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let reports = read_community_reports( + report_df, + "community", + Some("community"), + "title", + "community", + "summary", + "full_content", + Some("rank"), + None, + None, + )?; + Ok(reports) +} + +pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { + let mask = df.column("level")?.i64()?.lt_eq(community_level); + let result = df.filter(&mask)?; + + Ok(result) +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CommunityReport { + pub id: String, + pub short_id: Option, + pub title: String, + pub community_id: String, + pub summary: String, + pub full_content: String, + pub rank: Option, + pub summary_embedding: Option>, + pub full_content_embedding: Option>, + pub attributes: Option>, +} + +pub fn read_community_reports( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + community_col: &str, + summary_col: &str, + content_col: &str, + rank_col: Option<&str>, + _summary_embedding_col: Option<&str>, + _content_embedding_col: Option<&str>, + // attributes_cols: Option<&[&str]>, +) -> anyhow::Result> { + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + Some(community_col), + Some(summary_col), + Some(content_col), + rank_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + let column_names: Vec = column_names.into_iter().collect::>().into_vec(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut reports = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = CommunityReport { + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + community_id: get_field(&row, community_col, &column_names) + .map(|community| community.to_string()) + .unwrap_or(String::new()), + summary: get_field(&row, summary_col, &column_names) + .map(|summary| summary.to_string()) + .unwrap_or(String::new()), + full_content: get_field(&row, content_col, &column_names) + .map(|content| content.to_string()) + .unwrap_or(String::new()), + rank: rank_col + .map(|rank_col| { + get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }) + .flatten(), + summary_embedding: None, + full_content_embedding: None, + attributes: None, + }; + reports.push(report); + } + + let mut unique_reports: Vec = Vec::new(); + let mut report_ids: HashSet = HashSet::new(); + + for report in reports { + if !report_ids.contains(&report.id) { + unique_reports.push(report.clone()); + report_ids.insert(report.id); + } + } + + Ok(unique_reports) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs new file mode 100644 index 000000000..0abed5320 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -0,0 +1,4 @@ +pub mod community_context; +pub mod context_builder; +pub mod indexer_entities; +pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs new file mode 100644 index 000000000..08bc3d655 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/lib.rs @@ -0,0 +1,3 @@ +pub mod context_builder; +pub mod llm; +pub mod search; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs new file mode 100644 index 000000000..5fa5cf633 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -0,0 +1,55 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +pub struct BaseLLMCallback { + pub response: Vec, +} + +impl BaseLLMCallback { + pub fn new() -> Self { + BaseLLMCallback { response: Vec::new() } + } + + pub fn on_llm_new_token(&mut self, token: &str) { + self.response.push(token.to_string()); + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MessageType { + String(String), + Strings(Vec), + Dictionary(Vec>), +} + +#[derive(Debug, Clone)] +pub struct LLMParams { + pub max_tokens: u32, + pub temperature: f32, + pub response_format: HashMap, +} + +#[async_trait] +pub trait BaseLLM { + async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: LLMParams, + search_phase: Option, + ) -> anyhow::Result; +} + +#[async_trait] +pub trait BaseTextEmbedding { + async fn aembed(&self, text: &str) -> Vec; +} + +pub enum GlobalSearchPhase { + Map, + Reduce, +} diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs new file mode 100644 index 000000000..214bbef7c --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -0,0 +1 @@ +pub mod llm; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs new file mode 100644 index 000000000..8bb60b955 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -0,0 +1,434 @@ +use futures::future::join_all; +use polars::frame::DataFrame; +use serde_json::Value; +use std::collections::HashMap; +use std::time::Instant; + +use crate::context_builder::community_context::GlobalCommunityContext; +use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; +use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use crate::search::global_search::prompts::NO_DATA_ANSWER; + +use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; + +#[derive(Debug, Clone)] +pub struct SearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, +} + +#[derive(Debug, Clone)] +pub enum ResponseType { + String(String), + KeyPoints(Vec), +} + +#[derive(Debug, Clone)] +pub enum ContextData { + String(String), + DataFrames(Vec), + Dictionary(HashMap), +} + +#[derive(Debug, Clone)] +pub enum ContextText { + String(String), + Strings(Vec), + Dictionary(HashMap), +} + +#[derive(Debug, Clone)] +pub struct KeyPoint { + pub answer: String, + pub score: i32, +} + +pub struct GlobalSearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, + pub map_responses: Vec, + pub reduce_context_data: ContextData, + pub reduce_context_text: ContextText, +} + +#[derive(Debug, Clone)] +pub struct GlobalSearchLLMCallback { + response: Vec, + map_response_contexts: Vec, + map_response_outputs: Vec, +} + +impl GlobalSearchLLMCallback { + pub fn new() -> Self { + GlobalSearchLLMCallback { + response: Vec::new(), + map_response_contexts: Vec::new(), + map_response_outputs: Vec::new(), + } + } + + pub fn on_map_response_start(&mut self, map_response_contexts: Vec) { + self.map_response_contexts = map_response_contexts; + } + + pub fn on_map_response_end(&mut self, map_response_outputs: Vec) { + self.map_response_outputs = map_response_outputs; + } +} + +pub struct GlobalSearch { + llm: Box, + context_builder: GlobalCommunityContext, + num_tokens_fn: fn(&str) -> usize, + context_builder_params: ContextBuilderParams, + map_system_prompt: String, + reduce_system_prompt: String, + response_type: String, + allow_general_knowledge: bool, + general_knowledge_inclusion_prompt: String, + callbacks: Option>, + max_data_tokens: usize, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, +} + +pub struct GlobalSearchParams { + pub llm: Box, + pub context_builder: GlobalCommunityContext, + pub num_tokens_fn: fn(&str) -> usize, + pub map_system_prompt: Option, + pub reduce_system_prompt: Option, + pub response_type: String, + pub allow_general_knowledge: bool, + pub general_knowledge_inclusion_prompt: Option, + pub json_mode: bool, + pub callbacks: Option>, + pub max_data_tokens: usize, + pub map_llm_params: LLMParams, + pub reduce_llm_params: LLMParams, + pub context_builder_params: ContextBuilderParams, +} + +impl GlobalSearch { + pub fn new(global_search_params: GlobalSearchParams) -> Self { + let GlobalSearchParams { + llm, + context_builder, + num_tokens_fn, + map_system_prompt, + reduce_system_prompt, + response_type, + allow_general_knowledge, + general_knowledge_inclusion_prompt, + json_mode, + callbacks, + max_data_tokens, + map_llm_params, + reduce_llm_params, + context_builder_params, + } = global_search_params; + + let mut map_llm_params = map_llm_params; + + if json_mode { + map_llm_params + .response_format + .insert("type".to_string(), "json_object".to_string()); + } else { + map_llm_params.response_format.remove("response_format"); + } + + let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string()); + let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string()); + let general_knowledge_inclusion_prompt = + general_knowledge_inclusion_prompt.unwrap_or(GENERAL_KNOWLEDGE_INSTRUCTION.to_string()); + + GlobalSearch { + llm, + context_builder, + num_tokens_fn, + context_builder_params, + map_system_prompt, + reduce_system_prompt, + response_type, + allow_general_knowledge, + general_knowledge_inclusion_prompt, + callbacks, + max_data_tokens, + map_llm_params, + reduce_llm_params, + } + } + + pub async fn asearch( + &self, + query: String, + _conversation_history: Option, + ) -> anyhow::Result { + // Step 1: Generate answers for each batch of community short summaries + let start_time = Instant::now(); + let (context_chunks, context_records) = self + .context_builder + .build_context(self.context_builder_params.clone()) + .await?; + + let mut callbacks = match &self.callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + let mut callback = callback.clone(); + callback.on_map_response_start(context_chunks.clone()); + llm_callbacks.push(callback); + } + Some(llm_callbacks) + } + None => None, + }; + + let map_responses: Vec<_> = join_all( + context_chunks + .iter() + .map(|data| self._map_response_single_batch(data, &query, self.map_llm_params.clone())), + ) + .await; + + let map_responses: Result, _> = map_responses.into_iter().collect(); + let map_responses = map_responses?; + + callbacks = match &callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + let mut callback = callback.clone(); + callback.on_map_response_end(map_responses.clone()); + llm_callbacks.push(callback); + } + Some(llm_callbacks) + } + None => None, + }; + + let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum(); + let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum(); + + // Step 2: Combine the intermediate answers from step 2 to generate the final answer + let reduce_response = self + ._reduce_response(map_responses.clone(), &query, callbacks, self.reduce_llm_params.clone()) + .await?; + + Ok(GlobalSearchResult { + response: reduce_response.response, + context_data: ContextData::Dictionary(context_records), + context_text: ContextText::Strings(context_chunks), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: map_llm_calls + reduce_response.llm_calls, + prompt_tokens: map_prompt_tokens + reduce_response.prompt_tokens, + map_responses, + reduce_context_data: reduce_response.context_data, + reduce_context_text: reduce_response.context_text, + }) + } + + async fn _map_response_single_batch( + &self, + context_data: &str, + query: &str, + llm_params: LLMParams, + ) -> anyhow::Result { + let start_time = Instant::now(); + let search_prompt = self.map_system_prompt.replace("{context_data}", context_data); + + let mut search_messages = Vec::new(); + search_messages.push(HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ])); + search_messages.push(HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ])); + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + false, + None, + llm_params, + Some(GlobalSearchPhase::Map), + ) + .await?; + + let processed_response = self.parse_search_response(&search_response); + + Ok(SearchResult { + response: ResponseType::KeyPoints(processed_response), + context_data: ContextData::String(context_data.to_string()), + context_text: ContextText::String(context_data.to_string()), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 1, + prompt_tokens: (self.num_tokens_fn)(&search_prompt), + }) + } + + fn parse_search_response(&self, search_response: &str) -> Vec { + let parsed_elements: Value = serde_json::from_str(search_response).unwrap_or_default(); + + if let Some(points) = parsed_elements.get("points") { + if let Some(points) = points.as_array() { + return points + .iter() + .filter(|element| element.get("description").is_some() && element.get("score").is_some()) + .map(|element| KeyPoint { + answer: element + .get("description") + .unwrap_or(&Value::String("".to_string())) + .to_string(), + score: element + .get("score") + .unwrap_or(&Value::Number(serde_json::Number::from(0))) + .as_i64() + .unwrap_or(0) as i32, + }) + .collect::>(); + } + } + + vec![KeyPoint { + answer: "".to_string(), + score: 0, + }] + } + + async fn _reduce_response( + &self, + map_responses: Vec, + query: &str, + callbacks: Option>, + llm_params: LLMParams, + ) -> anyhow::Result { + let start_time = Instant::now(); + let mut key_points: Vec> = Vec::new(); + + for (index, response) in map_responses.iter().enumerate() { + if let ResponseType::KeyPoints(response_list) = &response.response { + for key_point in response_list { + let mut point = HashMap::new(); + point.insert("analyst".to_string(), (index + 1).to_string()); + point.insert("answer".to_string(), key_point.answer.clone()); + point.insert("score".to_string(), key_point.score.to_string()); + key_points.push(point); + } + } + } + + let filtered_key_points: Vec> = key_points + .into_iter() + .filter(|point| point.get("score").unwrap().parse::().unwrap() > 0) + .collect(); + + if filtered_key_points.is_empty() && !self.allow_general_knowledge { + eprintln!("Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."); + + return Ok(SearchResult { + response: ResponseType::String(NO_DATA_ANSWER.to_string()), + context_data: ContextData::String("".to_string()), + context_text: ContextText::String("".to_string()), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 0, + prompt_tokens: 0, + }); + } + + let mut sorted_key_points = filtered_key_points; + sorted_key_points.sort_by(|a, b| { + b.get("score") + .unwrap() + .parse::() + .unwrap() + .cmp(&a.get("score").unwrap().parse::().unwrap()) + }); + + let mut data: Vec = Vec::new(); + let mut total_tokens = 0; + for point in sorted_key_points { + let mut formatted_response_data: Vec = Vec::new(); + formatted_response_data.push(format!("----Analyst {}----", point.get("analyst").unwrap())); + formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap())); + formatted_response_data.push(point.get("answer").unwrap().to_string()); + let formatted_response_text = formatted_response_data.join("\n"); + + if total_tokens + (self.num_tokens_fn)(&formatted_response_text) > self.max_data_tokens { + break; + } + + data.push(formatted_response_text.clone()); + total_tokens += (self.num_tokens_fn)(&formatted_response_text); + } + let text_data = data.join("\n\n"); + + let search_prompt = format!( + "{}\n{}", + self.reduce_system_prompt + .replace("{report_data}", &text_data) + .replace("{response_type}", &self.response_type), + if self.allow_general_knowledge { + self.general_knowledge_inclusion_prompt.clone() + } else { + String::new() + } + ); + + let search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; + + let llm_callbacks = match callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + llm_callbacks.push(BaseLLMCallback { + response: callback.response.clone(), + }); + } + Some(llm_callbacks) + } + None => None, + }; + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + true, + llm_callbacks, + llm_params, + Some(GlobalSearchPhase::Reduce), + ) + .await?; + + Ok(SearchResult { + response: ResponseType::String(search_response), + context_data: ContextData::String(text_data.clone()), + context_text: ContextText::String(text_data), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 1, + prompt_tokens: (self.num_tokens_fn)(&search_prompt), + }) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs new file mode 100644 index 000000000..79f16f1e0 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs @@ -0,0 +1,2 @@ +pub mod global_search; +pub mod prompts; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs new file mode 100644 index 000000000..7c9fef5cb --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs @@ -0,0 +1,164 @@ +// Copyright (c) 2024 Microsoft Corporation. +// Licensed under the MIT License + +// System prompts for global search. + +pub const MAP_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +"#; + +pub const REDUCE_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +"#; + +pub const NO_DATA_ANSWER: &str = "I am sorry but I am unable to answer this question given the provided data."; + +pub const GENERAL_KNOWLEDGE_INSTRUCTION: &str = r#" +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." +"#; diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs new file mode 100644 index 000000000..a12441830 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs @@ -0,0 +1 @@ +pub mod global_search; diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs new file mode 100644 index 000000000..5c888a8df --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -0,0 +1,211 @@ +use polars::{io::SerReader, prelude::ParquetReader}; +use shinkai_graphrag::{ + context_builder::{ + community_context::GlobalCommunityContext, context_builder::ContextBuilderParams, + indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, + }, + llm::llm::LLMParams, + search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, +}; +use utils::{ + ollama::Ollama, + openai::{num_tokens, ChatOpenAI}, +}; + +mod utils; + +// #[tokio::test] +async fn ollama_global_search_test() -> Result<(), Box> { + let base_url = "http://localhost:11434"; + let model_type = "llama3.1"; + + let llm = Ollama::new(base_url.to_string(), model_type.to_string()); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + // Using tiktoken for token count estimation + let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); + + let context_builder_params = ContextBuilderParams { + use_community_summary: false, // False means using full community reports. True means using community short summaries. + shuffle_data: true, + include_community_rank: true, + min_community_rank: 0, + community_rank_name: String::from("rank"), + include_community_weight: true, + community_weight_name: String::from("occurrence weight"), + normalize_community_weight: true, + max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + context_name: String::from("Reports"), + column_delimiter: String::from("|"), + }; + + // LLM params are ignored for Ollama + let map_llm_params = LLMParams { + max_tokens: 1000, + temperature: 0.0, + response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), + }; + + let reduce_llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + response_format: std::collections::HashMap::new(), + }; + + // Perform global search + + let search_engine = GlobalSearch::new(GlobalSearchParams { + llm: Box::new(llm), + context_builder, + num_tokens_fn: num_tokens, + map_system_prompt: None, + reduce_system_prompt: None, + response_type: String::from("multiple paragraphs"), + allow_general_knowledge: false, + general_knowledge_inclusion_prompt: None, + json_mode: true, + callbacks: None, + max_data_tokens: 5000, + map_llm_params, + reduce_llm_params, + context_builder_params, + }); + + let result = search_engine + .asearch( + "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(), + None, + ) + .await?; + + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens); + + Ok(()) +} + +// #[tokio::test] +async fn openai_global_search_test() -> Result<(), Box> { + let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); + let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); + + let llm = ChatOpenAI::new(Some(api_key), llm_model, 5); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); + + let context_builder_params = ContextBuilderParams { + use_community_summary: false, // False means using full community reports. True means using community short summaries. + shuffle_data: true, + include_community_rank: true, + min_community_rank: 0, + community_rank_name: String::from("rank"), + include_community_weight: true, + community_weight_name: String::from("occurrence weight"), + normalize_community_weight: true, + max_tokens: 12_000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + context_name: String::from("Reports"), + column_delimiter: String::from("|"), + }; + + let map_llm_params = LLMParams { + max_tokens: 1000, + temperature: 0.0, + response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), + }; + + let reduce_llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + response_format: std::collections::HashMap::new(), + }; + + // Perform global search + + let search_engine = GlobalSearch::new(GlobalSearchParams { + llm: Box::new(llm), + context_builder, + num_tokens_fn: num_tokens, + map_system_prompt: None, + reduce_system_prompt: None, + response_type: String::from("multiple paragraphs"), + allow_general_knowledge: false, + general_knowledge_inclusion_prompt: None, + json_mode: true, + callbacks: None, + max_data_tokens: 12_000, + map_llm_params, + reduce_llm_params, + context_builder_params, + }); + + let result = search_engine + .asearch( + "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(), + None, + ) + .await?; + + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens); + + Ok(()) +} diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs new file mode 100644 index 000000000..3ef32f620 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod ollama; +pub mod openai; diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs new file mode 100644 index 000000000..41d3619b8 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -0,0 +1,100 @@ +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct OllamaResponse { + pub model: String, + pub created_at: String, + pub message: OllamaMessage, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct OllamaMessage { + pub role: String, + pub content: String, +} + +pub struct Ollama { + base_url: String, + model_type: String, +} + +impl Ollama { + pub fn new(base_url: String, model_type: String) -> Self { + Ollama { base_url, model_type } + } +} + +#[async_trait] +impl BaseLLM for Ollama { + async fn agenerate( + &self, + messages: MessageType, + _streaming: bool, + _callbacks: Option>, + _llm_params: LLMParams, + search_phase: Option, + ) -> anyhow::Result { + let client = Client::new(); + let chat_url = format!("{}{}", &self.base_url, "/api/chat"); + + let messages_json = match messages { + MessageType::String(message) => json![message], + MessageType::Strings(messages) => json!(messages), + MessageType::Dictionary(messages) => { + let messages = match search_phase { + Some(GlobalSearchPhase::Map) => { + // Filter out system messages and convert them to user messages + messages + .into_iter() + .filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system")) + .map(|map| { + map.into_iter() + .map(|(key, value)| { + if key == "role" { + return (key, "user".to_string()); + } + (key, value) + }) + .collect() + }) + .collect() + } + Some(GlobalSearchPhase::Reduce) => { + // Convert roles to user + messages + .into_iter() + .map(|map| { + map.into_iter() + .map(|(key, value)| { + if key == "role" { + return (key, "user".to_string()); + } + (key, value) + }) + .collect() + }) + .collect() + } + _ => messages, + }; + + json!(messages) + } + }; + + let payload = json!({ + "model": self.model_type, + "messages": messages_json, + "stream": false, + }); + + let response = client.post(chat_url).json(&payload).send().await?; + let response = response.json::().await?; + + Ok(response.message.content) + } +} diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs new file mode 100644 index 000000000..255d5b4e5 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -0,0 +1,146 @@ +use async_openai::{ + config::OpenAIConfig, + types::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionResponseFormat, + ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, + }, + Client, +}; +use async_trait::async_trait; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; + +pub struct ChatOpenAI { + pub api_key: Option, + pub model: String, + pub max_retries: usize, +} + +impl ChatOpenAI { + pub fn new(api_key: Option, model: String, max_retries: usize) -> Self { + ChatOpenAI { + api_key, + model, + max_retries, + } + } + + pub async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: LLMParams, + ) -> anyhow::Result { + let mut retry_count = 0; + + loop { + match self + ._agenerate(messages.clone(), streaming, callbacks.clone(), llm_params.clone()) + .await + { + Ok(response) => return Ok(response), + Err(e) => { + if retry_count < self.max_retries { + retry_count += 1; + continue; + } + return Err(e); + } + } + } + } + + async fn _agenerate( + &self, + messages: MessageType, + _streaming: bool, + _callbacks: Option>, + llm_params: LLMParams, + ) -> anyhow::Result { + let client = match &self.api_key { + Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), + None => Client::new(), + }; + + let messages = match messages { + MessageType::String(message) => vec![message], + MessageType::Strings(messages) => messages, + MessageType::Dictionary(messages) => { + let messages = messages + .iter() + .map(|message_map| { + message_map + .iter() + .map(|(key, value)| format!("{}: {}", key, value)) + .collect::>() + .join("\n") + }) + .collect(); + messages + } + }; + + let request_messages = messages + .into_iter() + .map(|m| ChatCompletionRequestSystemMessageArgs::default().content(m).build()) + .collect::>(); + + let request_messages: Result, _> = request_messages.into_iter().collect(); + let request_messages = request_messages?; + let request_messages = request_messages + .into_iter() + .map(|m| Into::::into(m.clone())) + .collect::>(); + + let _response_format = if llm_params + .response_format + .get_key_value("type") + .is_some_and(|(_k, v)| v == "json_object") + { + ChatCompletionResponseFormat { + r#type: ChatCompletionResponseFormatType::JsonObject, + } + } else { + ChatCompletionResponseFormat { + r#type: ChatCompletionResponseFormatType::Text, + } + }; + + let request = CreateChatCompletionRequestArgs::default() + .max_tokens(llm_params.max_tokens) + .temperature(llm_params.temperature) + //.response_format(response_format) + .model(self.model.clone()) + .messages(request_messages) + .build()?; + + let response = client.chat().create(request).await?; + + if let Some(choice) = response.choices.get(0) { + return Ok(choice.message.content.clone().unwrap_or_default()); + } + + return Ok(String::new()); + } +} + +#[async_trait] +impl BaseLLM for ChatOpenAI { + async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: LLMParams, + _search_phase: Option, + ) -> anyhow::Result { + self.agenerate(messages, streaming, callbacks, llm_params).await + } +} + +pub fn num_tokens(text: &str) -> usize { + let token_encoder = Tokenizer::Cl100kBase; + let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); + bpe.encode_with_special_tokens(text).len() +}